From b5b1441478cd89f6c61f81db56cf8435cd80d060 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 23 Oct 2023 13:07:34 +0000 Subject: [PATCH 01/12] v1 --- src/transformers/modeling_utils.py | 19 ++++++++++++++++--- .../models/align/modeling_align.py | 7 ++++--- .../models/altclip/modeling_altclip.py | 12 +++++++----- .../modeling_audio_spectrogram_transformer.py | 7 ++++--- .../models/autoformer/modeling_autoformer.py | 9 +++++---- src/transformers/models/bark/modeling_bark.py | 7 ++++--- src/transformers/models/bart/modeling_bart.py | 9 +++++---- src/transformers/models/beit/modeling_beit.py | 7 ++++--- src/transformers/models/bert/modeling_bert.py | 7 ++++--- .../modeling_bert_generation.py | 7 ++++--- .../models/big_bird/modeling_big_bird.py | 7 ++++--- .../modeling_bigbird_pegasus.py | 9 +++++---- .../models/biogpt/modeling_biogpt.py | 7 ++++--- src/transformers/models/bit/modeling_bit.py | 5 +++-- .../models/blenderbot/modeling_blenderbot.py | 9 +++++---- .../modeling_blenderbot_small.py | 9 +++++---- src/transformers/models/blip/modeling_blip.py | 7 ++++--- .../models/blip/modeling_blip_text.py | 2 +- .../models/blip_2/modeling_blip_2.py | 9 +++++---- .../models/bloom/modeling_bloom.py | 7 ++++--- .../bridgetower/modeling_bridgetower.py | 2 +- src/transformers/models/bros/modeling_bros.py | 2 +- .../models/camembert/modeling_camembert.py | 7 ++++--- .../models/canine/modeling_canine.py | 7 ++++--- .../chinese_clip/modeling_chinese_clip.py | 9 +++++---- src/transformers/models/clap/modeling_clap.py | 9 +++++---- src/transformers/models/clip/modeling_clip.py | 7 ++++--- .../models/clipseg/modeling_clipseg.py | 7 ++++--- .../models/codegen/modeling_codegen.py | 7 ++++--- .../modeling_conditional_detr.py | 7 ++++--- .../models/convbert/modeling_convbert.py | 7 ++++--- .../models/convnext/modeling_convnext.py | 5 +++-- .../models/convnextv2/modeling_convnextv2.py | 5 +++-- .../models/cpmant/modeling_cpmant.py | 5 +++-- .../data2vec/modeling_data2vec_audio.py | 9 +++++---- .../models/data2vec/modeling_data2vec_text.py | 7 ++++--- .../data2vec/modeling_data2vec_vision.py | 7 ++++--- .../models/deberta/modeling_deberta.py | 7 ++++--- .../models/deberta_v2/modeling_deberta_v2.py | 7 ++++--- .../modeling_decision_transformer.py | 7 ++++--- .../modeling_deformable_detr.py | 7 ++++--- src/transformers/models/deit/modeling_deit.py | 7 ++++--- .../models/deprecated/mctct/modeling_mctct.py | 7 ++++--- .../open_llama/modeling_open_llama.py | 7 ++++--- .../modeling_trajectory_transformer.py | 7 ++++--- .../models/deprecated/van/modeling_van.py | 5 +++-- src/transformers/models/deta/modeling_deta.py | 7 ++++--- src/transformers/models/detr/modeling_detr.py | 7 ++++--- .../models/dinat/modeling_dinat.py | 2 +- .../models/dinov2/modeling_dinov2.py | 7 ++++--- .../models/distilbert/modeling_distilbert.py | 7 ++++--- .../models/donut/modeling_donut_swin.py | 7 ++++--- src/transformers/models/dpr/modeling_dpr.py | 5 +++-- src/transformers/models/dpt/modeling_dpt.py | 7 ++++--- .../efficientnet/modeling_efficientnet.py | 5 +++-- .../models/electra/modeling_electra.py | 7 ++++--- .../models/encodec/modeling_encodec.py | 5 +++-- .../modeling_encoder_decoder.py | 6 +++--- .../models/ernie/modeling_ernie.py | 7 ++++--- .../models/ernie_m/modeling_ernie_m.py | 5 +++-- src/transformers/models/esm/modeling_esm.py | 7 ++++--- .../models/falcon/modeling_falcon.py | 7 ++++--- .../models/flava/modeling_flava.py | 7 ++++--- src/transformers/models/fnet/modeling_fnet.py | 7 ++++--- 
.../models/focalnet/modeling_focalnet.py | 7 ++++--- src/transformers/models/fuyu/modeling_fuyu.py | 5 +++-- src/transformers/models/git/modeling_git.py | 9 +++++---- src/transformers/models/gpt2/modeling_gpt2.py | 7 ++++--- .../gpt_bigcode/modeling_gpt_bigcode.py | 7 ++++--- .../models/gpt_neo/modeling_gpt_neo.py | 7 ++++--- .../models/gpt_neox/modeling_gpt_neox.py | 7 ++++--- .../modeling_gpt_neox_japanese.py | 5 +++-- src/transformers/models/gptj/modeling_gptj.py | 7 ++++--- .../modeling_gptsan_japanese.py | 5 +++-- .../models/graphormer/modeling_graphormer.py | 5 +++-- .../models/groupvit/modeling_groupvit.py | 7 ++++--- .../models/hubert/modeling_hubert.py | 11 ++++++----- .../models/idefics/modeling_idefics.py | 7 ++++--- src/transformers/models/idefics/vision.py | 2 +- .../models/imagegpt/modeling_imagegpt.py | 7 ++++--- .../models/informer/modeling_informer.py | 11 ++++++----- .../instructblip/modeling_instructblip.py | 9 +++++---- .../models/layoutlm/modeling_layoutlm.py | 7 ++++--- .../models/layoutlmv2/modeling_layoutlmv2.py | 7 ++++--- .../models/layoutlmv3/modeling_layoutlmv3.py | 2 +- src/transformers/models/led/modeling_led.py | 9 +++++---- .../models/levit/modeling_levit.py | 5 +++-- src/transformers/models/lilt/modeling_lilt.py | 7 ++++--- .../models/llama/modeling_llama.py | 7 ++++--- .../models/longformer/modeling_longformer.py | 7 ++++--- .../models/longt5/modeling_longt5.py | 5 +++-- src/transformers/models/luke/modeling_luke.py | 7 ++++--- .../models/m2m_100/modeling_m2m_100.py | 9 +++++---- .../models/marian/modeling_marian.py | 9 +++++---- .../models/markuplm/modeling_markuplm.py | 2 +- .../mask2former/modeling_mask2former.py | 2 +- .../models/maskformer/modeling_maskformer.py | 10 ++++++---- .../maskformer/modeling_maskformer_swin.py | 7 ++++--- .../models/mbart/modeling_mbart.py | 9 +++++---- .../megatron_bert/modeling_megatron_bert.py | 7 ++++--- .../models/mgp_str/modeling_mgp_str.py | 5 +++-- .../models/mistral/modeling_mistral.py | 7 ++++--- .../models/mobilevit/modeling_mobilevit.py | 7 ++++--- .../mobilevitv2/modeling_mobilevitv2.py | 7 ++++--- src/transformers/models/mpt/modeling_mpt.py | 7 ++++--- src/transformers/models/mra/modeling_mra.py | 7 ++++--- src/transformers/models/mt5/modeling_mt5.py | 5 +++-- .../models/musicgen/modeling_musicgen.py | 11 ++++++----- src/transformers/models/mvp/modeling_mvp.py | 9 +++++---- src/transformers/models/nat/modeling_nat.py | 2 +- .../models/nezha/modeling_nezha.py | 7 ++++--- .../models/nllb_moe/modeling_nllb_moe.py | 7 ++++--- .../nystromformer/modeling_nystromformer.py | 7 ++++--- .../models/oneformer/modeling_oneformer.py | 2 +- src/transformers/models/opt/modeling_opt.py | 7 ++++--- .../models/owlv2/modeling_owlv2.py | 8 +++++--- .../models/owlvit/modeling_owlvit.py | 7 ++++--- .../models/pegasus/modeling_pegasus.py | 9 +++++---- .../models/pegasus_x/modeling_pegasus_x.py | 9 +++++---- .../models/persimmon/modeling_persimmon.py | 7 ++++--- .../models/pix2struct/modeling_pix2struct.py | 12 +++++++----- .../models/plbart/modeling_plbart.py | 9 +++++---- .../models/poolformer/modeling_poolformer.py | 5 +++-- .../models/pop2piano/modeling_pop2piano.py | 5 +++-- .../models/prophetnet/modeling_prophetnet.py | 9 +++++---- src/transformers/models/pvt/modeling_pvt.py | 5 +++-- .../models/qdqbert/modeling_qdqbert.py | 7 ++++--- .../models/realm/modeling_realm.py | 2 +- .../models/regnet/modeling_regnet.py | 5 +++-- .../models/rembert/modeling_rembert.py | 7 ++++--- .../models/resnet/modeling_resnet.py | 5 +++-- 
.../models/roberta/modeling_roberta.py | 7 ++++--- .../modeling_roberta_prelayernorm.py | 7 ++++--- .../models/roc_bert/modeling_roc_bert.py | 7 ++++--- .../models/roformer/modeling_roformer.py | 7 ++++--- src/transformers/models/rwkv/modeling_rwkv.py | 7 ++++--- src/transformers/models/sam/modeling_sam.py | 2 +- src/transformers/models/sew/modeling_sew.py | 9 +++++---- .../models/sew_d/modeling_sew_d.py | 9 +++++---- .../modeling_speech_encoder_decoder.py | 6 +++--- .../speech_to_text/modeling_speech_to_text.py | 9 +++++---- .../modeling_speech_to_text_2.py | 7 ++++--- .../models/speecht5/modeling_speecht5.py | 11 ++++++----- .../models/splinter/modeling_splinter.py | 7 ++++--- .../swiftformer/modeling_swiftformer.py | 5 +++-- src/transformers/models/swin/modeling_swin.py | 7 ++++--- .../models/swin/modeling_tf_swin.py | 5 ----- .../models/swin2sr/modeling_swin2sr.py | 7 ++++--- .../models/swinv2/modeling_swinv2.py | 7 ++++--- .../modeling_switch_transformers.py | 5 +++-- src/transformers/models/t5/modeling_t5.py | 5 +++-- .../modeling_table_transformer.py | 7 ++++--- .../models/tapas/modeling_tapas.py | 7 ++++--- .../modeling_time_series_transformer.py | 9 +++++---- .../timesformer/modeling_timesformer.py | 7 ++++--- .../models/trocr/modeling_trocr.py | 7 ++++--- src/transformers/models/tvlt/modeling_tvlt.py | 9 +++++---- src/transformers/models/umt5/modeling_umt5.py | 5 +++-- .../models/unispeech/modeling_unispeech.py | 11 ++++++----- .../unispeech_sat/modeling_unispeech_sat.py | 11 ++++++----- .../models/upernet/modeling_upernet.py | 5 +++-- .../models/videomae/modeling_videomae.py | 9 +++++---- src/transformers/models/vilt/modeling_vilt.py | 7 ++++--- .../modeling_vision_encoder_decoder.py | 6 +++--- .../visual_bert/modeling_visual_bert.py | 7 ++++--- src/transformers/models/vit/modeling_vit.py | 7 ++++--- .../models/vit_hybrid/modeling_vit_hybrid.py | 7 ++++--- .../models/vit_mae/modeling_vit_mae.py | 9 +++++---- .../models/vit_msn/modeling_vit_msn.py | 7 ++++--- .../models/vitdet/modeling_vitdet.py | 7 ++++--- .../models/vitmatte/modeling_vitmatte.py | 5 +++-- src/transformers/models/vits/modeling_vits.py | 7 ++++--- .../models/vivit/modeling_vivit.py | 7 ++++--- .../models/wav2vec2/modeling_wav2vec2.py | 11 ++++++----- .../modeling_wav2vec2_conformer.py | 9 +++++---- .../models/wavlm/modeling_wavlm.py | 11 ++++++----- .../models/whisper/modeling_whisper.py | 9 +++++---- .../models/x_clip/modeling_x_clip.py | 9 +++++---- src/transformers/models/xglm/modeling_xglm.py | 7 ++++--- .../xlm_prophetnet/modeling_xlm_prophetnet.py | 9 +++++---- .../xlm_roberta/modeling_xlm_roberta.py | 7 ++++--- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 2 +- src/transformers/models/xmod/modeling_xmod.py | 7 ++++--- .../models/yolos/modeling_yolos.py | 7 ++++--- src/transformers/models/yoso/modeling_yoso.py | 7 ++++--- ...ng_{{cookiecutter.lowercase_modelname}}.py | 16 +++++++++------- tests/test_modeling_common.py | 7 +++++++ 187 files changed, 750 insertions(+), 562 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 0317695f2096de..4da8bfcc90439b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import collections +import functools import gc import importlib.metadata import inspect @@ -1819,16 +1820,28 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]): self.base_model._prune_heads(heads_to_prune) - def gradient_checkpointing_enable(self): + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): """ Activates gradient checkpointing for the current model. Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint activations". + + Args: + gradient_checkpointing_kwargs (dict, *optional*): + Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. """ if not self.supports_gradient_checkpointing: raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") - self.apply(partial(self._set_gradient_checkpointing, value=True)) + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {} + + gradient_checkpointing_func = functools.partial( + torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs + ) + + self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=gradient_checkpointing_func)) if getattr(self, "_hf_peft_config_loaded", False): # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True @@ -1845,7 +1858,7 @@ def gradient_checkpointing_disable(self): activations". """ if self.supports_gradient_checkpointing: - self.apply(partial(self._set_gradient_checkpointing, value=False)) + self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=None)) if getattr(self, "_hf_peft_config_loaded", False): self.disable_input_require_grads() diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 6cbf01a3432ccb..bad7db0150c85b 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -1102,7 +1102,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -1197,9 +1197,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (AlignTextModel, AlignVisionModel)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index c4e32de55d9c03..56b9657aecb700 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -653,7 +653,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -967,7 +967,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1089,11 +1089,13 @@ 
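
For reference, a minimal usage sketch of the new `gradient_checkpointing_enable` signature introduced above. It assumes this branch is installed; `AutoModelForCausalLM`/"gpt2" and the `use_reentrant` kwarg are only illustrative — any checkpointing-capable model and any kwarg accepted by `torch.utils.checkpoint.checkpoint` would do:

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Old call still works: gradient_checkpointing_kwargs defaults to an
    # empty dict, so torch.utils.checkpoint.checkpoint behaves as before.
    model.gradient_checkpointing_enable()

    # New: forward arbitrary kwargs to torch.utils.checkpoint.checkpoint,
    # e.g. opt out of the re-entrant checkpointing implementation.
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": False}
    )

    # Internally this builds
    #     functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
    # and hands it to every submodule's _set_gradient_checkpointing, which stores
    # it as `module.gradient_checkpointing_func` and sets
    # `module.gradient_checkpointing = True`; the per-model hunks below then call
    # self.gradient_checkpointing_func(...) in place of the previously hard-coded
    # torch.utils.checkpoint.checkpoint(...).

    # Disabling resets gradient_checkpointing_func to None on all submodules.
    model.gradient_checkpointing_disable()
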
def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, AltCLIPEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None if isinstance(module, AltRobertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIPVisionTransformer->AltCLIPVisionTransformer,CLIPVisionConfig->AltCLIPVisionConfig,CLIPVisionEmbeddings->AltCLIPVisionEmbeddings,CLIPEncoder->AltCLIPEncoder,CLIP_VISION_INPUTS_DOCSTRING->ALTCLIP_VISION_INPUTS_DOCSTRING diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index 28969f50b67291..2a895dc073ba71 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -343,7 +343,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -395,9 +395,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.weight.data.fill_(1.0) # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST - def _set_gradient_checkpointing(self, module: ASTEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ASTEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ASTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 96298c77a344e7..278811d23d9a34 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -946,9 +946,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (AutoformerDecoder, AutoformerEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None AUTOFORMER_START_DOCSTRING = r""" @@ -1214,7 +1215,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1433,7 +1434,7 @@ def 
custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 649719e0eefa5d..8ffb22fd3e5d9f 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -313,9 +313,10 @@ def device(self) -> torch.device: return get_parameter_device(self) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BarkCausalModel) or isinstance(module, BarkFineModel) or isinstance(module, BarkModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BARK_MODEL_START_DOCSTRING = """ @@ -645,7 +646,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 9e7763ca23d885..2af67b87b7394c 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -521,9 +521,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BartDecoder, BartEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -861,7 +862,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1118,7 +1119,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index d698cff88b146e..d30eff63f54131 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -517,7 +517,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -572,9 +572,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BeitEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BEIT_START_DOCSTRING = r""" diff --git 
a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 1b0fad3f9d6546..993160d6998a83 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -600,7 +600,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -762,9 +762,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index abe2d828b28bb9..c811f2d19d3187 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -408,7 +408,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -607,9 +607,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BERT_GENERATION_START_DOCSTRING = r""" diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index e266b1a67b7d41..6677e658b8dd50 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1624,7 +1624,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -1784,9 +1784,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BigBirdEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BIG_BIRD_START_DOCSTRING = r""" diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 4e279f9dc059fe..98d8ae83179f5d 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1609,9 +1609,10 @@ def _init_weights(self, module): if module.padding_idx is not None: 
module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BigBirdPegasusDecoder, BigBirdPegasusEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -1950,7 +1951,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -2297,7 +2298,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index ca084db5c7d0b9..6597d2ea04e659 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -376,9 +376,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BioGptModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BIOGPT_START_DOCSTRING = r""" @@ -598,7 +599,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py index 12a5ecd42b74cf..d02861d6343da5 100644 --- a/src/transformers/models/bit/modeling_bit.py +++ b/src/transformers/models/bit/modeling_bit.py @@ -669,9 +669,10 @@ def _init_weights(self, module): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BitModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 1db81905210b63..caaf59d289a293 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -483,9 +483,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BlenderbotDecoder, BlenderbotEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = 
gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -784,7 +785,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1040,7 +1041,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 129de3dd1456e3..d72ee4ceb5589b 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -480,9 +480,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BlenderbotSmallDecoder, BlenderbotSmallEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -782,7 +783,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1037,7 +1038,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 9fca7c28a1a07d..59f2590d04ee00 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -461,9 +461,10 @@ def _init_weights(self, module): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BlipEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BLIP_START_DOCSTRING = r""" @@ -629,7 +630,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 49b958afc2ebae..317eea1e1b6e4a 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -429,7 +429,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git 
a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index bd56b17e55c21b..bcb6f4f7b6c6f4 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -297,9 +297,10 @@ def _init_weights(self, module): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Blip2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BLIP_2_START_DOCSTRING = r""" @@ -480,7 +481,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -951,7 +952,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index d90bb6ad8fdfd5..688415ac11218b 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -496,9 +496,10 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False): + def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): if isinstance(module, BloomModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @staticmethod def _convert_to_standard_cache( @@ -769,7 +770,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, alibi, diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index ce569157b811c2..d64532170bbffa 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -811,7 +811,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index a8ea8d49195b88..603dc2a52b8fdb 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -658,7 +658,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, bbox_pos_emb, diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 8d7d279579e3e9..44764f900abb9a 100644 --- 
a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -531,7 +531,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -625,9 +625,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CamembertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CAMEMBERT_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 657104ad696535..9625e97ea28b7b 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -802,7 +802,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -919,9 +919,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CanineEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CANINE_START_DOCSTRING = r""" diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 7bab0aea6eb95d..1f4a42732d7deb 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -742,9 +742,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ChineseCLIPVisionEncoder) or isinstance(module, ChineseCLIPTextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CHINESE_CLIP_START_DOCSTRING = r""" @@ -916,7 +917,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -1025,7 +1026,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, ) diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 1d17a518838734..ccee38322c0b84 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -946,7 +946,7 @@ def 
custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: @@ -1602,7 +1602,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -1701,9 +1701,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ClapTextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class ClapAudioModel(ClapPreTrainedModel): diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 3a894b9727c92b..e179244a1c326c 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -467,9 +467,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CLIPEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CLIP_START_DOCSTRING = r""" @@ -646,7 +647,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 96f13217aaf821..385737bafe3368 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -479,9 +479,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CLIPSegEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CLIPSEG_START_DOCSTRING = r""" @@ -655,7 +656,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 172a45544bac0d..464eeebc9ba0d9 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -337,9 +337,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + 
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CodeGenModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CODEGEN_START_DOCSTRING = r""" @@ -548,7 +549,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 15f24084f46995..2a4812eaf04897 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -1169,9 +1169,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConditionalDetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CONDITIONAL_DETR_START_DOCSTRING = r""" @@ -1523,7 +1524,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index a6fccf5b72b443..927c026df777a8 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -264,9 +264,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConvBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class SeparableConv1D(nn.Module): @@ -639,7 +640,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py index e6cf336517a563..e11112b5322266 100755 --- a/src/transformers/models/convnext/modeling_convnext.py +++ b/src/transformers/models/convnext/modeling_convnext.py @@ -296,9 +296,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConvNextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CONVNEXT_START_DOCSTRING = r""" diff --git 
a/src/transformers/models/convnextv2/modeling_convnextv2.py b/src/transformers/models/convnextv2/modeling_convnextv2.py index 3a268c713d502a..f1ff89bb124398 100644 --- a/src/transformers/models/convnextv2/modeling_convnextv2.py +++ b/src/transformers/models/convnextv2/modeling_convnextv2.py @@ -317,9 +317,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConvNextV2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CONVNEXTV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/cpmant/modeling_cpmant.py b/src/transformers/models/cpmant/modeling_cpmant.py index 6d2dc596fa65ff..8a6c744ed69ecc 100755 --- a/src/transformers/models/cpmant/modeling_cpmant.py +++ b/src/transformers/models/cpmant/modeling_cpmant.py @@ -556,9 +556,10 @@ def _init_weights(self, module): elif isinstance(module, CpmAntSegmentPositionEmbedding): module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CpmAntEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CPMANT_START_DOCSTRING = r""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 4435e9b8d01754..6d8bb5c2058c3a 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -300,7 +300,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -600,7 +600,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -761,9 +761,10 @@ def _get_feature_vector_attention_mask( attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Data2VecAudioEncoder, Data2VecAudioFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DATA2VEC_AUDIO_START_DOCSTRING = r""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index a521ccb39aaf0c..66588647f61bb5 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -517,7 +517,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), 
hidden_states, attention_mask, @@ -613,9 +613,10 @@ def _init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Data2VecTextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DATA2VECTEXT_START_DOCSTRING = r""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index f8fe59587af0cc..e7fd98091f974c 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -529,7 +529,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -585,9 +585,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Data2VecVisionEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DATA2VEC_VISION_START_DOCSTRING = r""" diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 6f6c2af63a672e..06a33a7dd85c4f 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -464,7 +464,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(layer_module), next_kv, attention_mask, @@ -839,9 +839,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DebertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index eda4f406cb316d..2172f5d22eefe2 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -508,7 +508,7 @@ def custom_forward(*inputs): return custom_forward - output_states = torch.utils.checkpoint.checkpoint( + output_states = self.gradient_checkpointing_func( create_custom_forward(layer_module), next_kv, attention_mask, @@ -938,9 +938,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, 
DebertaV2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 8e5053a4160d12..3865fe523f7168 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -469,9 +469,10 @@ def _init_weights(self, module): # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DecisionTransformerGPT2Model): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): @@ -639,7 +640,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index f541ca130544dd..7e04d2a1c76069 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1086,9 +1086,10 @@ def _init_weights(self, module): if hasattr(module, "level_embed"): nn.init.normal_(module.level_embed) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DeformableDetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEFORMABLE_DETR_START_DOCSTRING = r""" @@ -1388,7 +1389,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, encoder_hidden_states, diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 38c28dbbedc669..ff95a458ad77ef 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -364,7 +364,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -415,9 +415,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: DeiTEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: DeiTEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, DeiTEncoder): - 
module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index eca5ba014e51a6..e38b89a0a4441c 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -504,9 +504,10 @@ def _get_feature_vector_attention_mask(self, feature_vector_length, attention_ma attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MCTCTEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MCTCT_START_DOCSTRING = r""" @@ -623,7 +624,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 6853f5333f137c..f021714be25060 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -456,9 +456,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, OpenLlamaModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OPEN_LLAMA_INPUTS_DOCSTRING = r""" @@ -673,7 +674,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py index 75415dbe77bf07..13a26b6c05d584 100644 --- a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py @@ -163,9 +163,10 @@ class TrajectoryTransformerPreTrainedModel(PreTrainedModel): main_input_name = "trajectories" supports_gradient_checkpointing = True - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TrajectoryTransformerModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Embedding)): @@ -557,7 +558,7 @@ 
def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, layer_past, diff --git a/src/transformers/models/deprecated/van/modeling_van.py b/src/transformers/models/deprecated/van/modeling_van.py index 4ef18f54158f91..52c9e1242422c6 100644 --- a/src/transformers/models/deprecated/van/modeling_van.py +++ b/src/transformers/models/deprecated/van/modeling_van.py @@ -387,9 +387,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, VanModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VAN_START_DOCSTRING = r""" diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index 9cd29e94088730..2c5890e0a35733 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -977,9 +977,10 @@ def _init_weights(self, module): if hasattr(module, "level_embed"): nn.init.normal_(module.level_embed) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DetaDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DETA_START_DOCSTRING = r""" @@ -1280,7 +1281,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, encoder_hidden_states, diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 3dda00a20082cc..4200e6556d507b 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -925,9 +925,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DETR_START_DOCSTRING = r""" @@ -1258,7 +1259,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py index 89c6ed2e2a88e9..eb4d3f2ff29646 100644 --- a/src/transformers/models/dinat/modeling_dinat.py +++ b/src/transformers/models/dinat/modeling_dinat.py @@ -660,7 +660,7 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: DinatEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, 
module: DinatEncoder, gradient_checkpointing_func=None) -> None: pass diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 6e4446faddd5e9..656a3022c96f4a 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -454,7 +454,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -516,9 +516,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ).to(module.cls_token.dtype) - def _set_gradient_checkpointing(self, module: Dinov2Encoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: Dinov2Encoder, gradient_checkpointing_func=None) -> None: if isinstance(module, Dinov2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DINOV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index f26b5846972d64..de3c125abbacdd 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -365,7 +365,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_state, attn_mask, @@ -430,9 +430,10 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Transformer): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DISTILBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 0d833406e259e6..1a1e215f9a6d0d 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -756,7 +756,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: @@ -826,9 +826,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DonutSwinEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIN_START_DOCSTRING = r""" diff --git a/src/transformers/models/dpr/modeling_dpr.py b/src/transformers/models/dpr/modeling_dpr.py index 944ce142b0ad02..c258343f6cfdc8 100644 --- a/src/transformers/models/dpr/modeling_dpr.py +++ 
b/src/transformers/models/dpr/modeling_dpr.py @@ -164,9 +164,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class DPREncoder(DPRPreTrainedModel): diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 187a6c36656a8e..b13ca04626cf57 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -535,7 +535,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -818,9 +818,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DPTViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DPT_START_DOCSTRING = r""" diff --git a/src/transformers/models/efficientnet/modeling_efficientnet.py b/src/transformers/models/efficientnet/modeling_efficientnet.py index 478aeecee02bc1..d1b2c99403437b 100644 --- a/src/transformers/models/efficientnet/modeling_efficientnet.py +++ b/src/transformers/models/efficientnet/modeling_efficientnet.py @@ -500,9 +500,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, EfficientNetBlock): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index da3ee8e51d3602..a7d943450a864b 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -578,7 +578,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -692,9 +692,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ElectraEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/encodec/modeling_encodec.py b/src/transformers/models/encodec/modeling_encodec.py index 697fb3c94fbb1d..28c20da3d5eb15 100644 --- 
a/src/transformers/models/encodec/modeling_encodec.py +++ b/src/transformers/models/encodec/modeling_encodec.py @@ -473,9 +473,10 @@ def _init_weights(self, module): elif "bias" in name: nn.init.constant_(param, 0.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (EncodecEncoder, EncodecDecoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ENCODEC_START_DOCSTRING = r""" diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 3548e48c595a4a..d64860d6263e0a 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -265,10 +265,10 @@ def tie_weights(self): self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_encoder(self): return self.encoder diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index d55155f80093bc..b178ca354495a4 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -513,7 +513,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -680,9 +680,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ErnieEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/ernie_m/modeling_ernie_m.py b/src/transformers/models/ernie_m/modeling_ernie_m.py index 9c53ddd73c8540..b26ee0fcafd19f 100755 --- a/src/transformers/models/ernie_m/modeling_ernie_m.py +++ b/src/transformers/models/ernie_m/modeling_ernie_m.py @@ -429,9 +429,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ErnieMEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ERNIE_M_START_DOCSTRING = r""" diff --git a/src/transformers/models/esm/modeling_esm.py 
b/src/transformers/models/esm/modeling_esm.py index 7a07495ba7e501..3115a1357ea6ea 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -612,7 +612,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -710,9 +710,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, EsmEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ESM_START_DOCSTRING = r""" diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index e9dca6df989472..29873a39457fb5 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -945,9 +945,10 @@ def _init_weights(self, module: nn.Module): module.weight.data.fill_(1.0) # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->FalconModel - def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False): + def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): if isinstance(module, FalconModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @staticmethod def _convert_cache_to_standard_format( @@ -1155,7 +1156,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, alibi, diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 8de647c8299a09..9b5faaeb15f643 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -670,7 +670,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -879,9 +879,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: FlavaEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: FlavaEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, FlavaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 45042147761d56..299b607b6b8adf 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -299,7 +299,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = 
torch.utils.checkpoint.checkpoint(create_custom_forward(layer_module), hidden_states) + layer_outputs = self.gradient_checkpointing_func(create_custom_forward(layer_module), hidden_states) else: layer_outputs = layer_module(hidden_states) @@ -431,9 +431,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, FNetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index 8d18a8c63fda1b..0e33dc4f66f438 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -593,7 +593,7 @@ def custom_forward(*inputs): return custom_forward - stage_outputs = torch.utils.checkpoint.checkpoint( + stage_outputs = self.gradient_checkpointing_func( create_custom_forward(stage_module), hidden_states, input_dimensions, @@ -659,9 +659,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, FocalNetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None FOCALNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index b14b1b0b871d3a..141976ef21c3d2 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -70,9 +70,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, FuyuForCausalLM): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None FUYU_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 00707e42dd085a..bcbee566fa247b 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -459,7 +459,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -533,9 +533,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (GitEncoder, GitVisionEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GIT_START_DOCSTRING = r""" 
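The two-line body repeated in every `_set_gradient_checkpointing` override above stores the checkpointing callable on the module and derives the boolean flag from it. For orientation only, here is a minimal sketch of how the caller side could be wired; the `gradient_checkpointing_enable`/`gradient_checkpointing_disable` helpers and the `use_reentrant` default are assumptions for illustration, not code taken from this patch's modeling_utils.py changes.

import functools

import torch


def gradient_checkpointing_enable(model, use_reentrant=False):
    # Sketch only: build the checkpointing callable once, then hand it to every
    # submodule that implements the hook. Each module keeps the callable in
    # `gradient_checkpointing_func` and sets `gradient_checkpointing = True`
    # because the callable is not None.
    checkpointing_func = functools.partial(
        torch.utils.checkpoint.checkpoint, use_reentrant=use_reentrant
    )
    model.apply(
        lambda module: model._set_gradient_checkpointing(
            module, gradient_checkpointing_func=checkpointing_func
        )
    )


def gradient_checkpointing_disable(model):
    # Passing no callable (None) flips the derived flag back to False.
    model.apply(lambda module: model._set_gradient_checkpointing(module))

During the forward pass each encoder then invokes `self.gradient_checkpointing_func(create_custom_forward(layer), ...)` with the same arguments the replaced `torch.utils.checkpoint.checkpoint` calls received, which is why every hunk below reduces to a one-line swap plus the two assignments in `_set_gradient_checkpointing`.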
@@ -885,7 +886,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 838e7ca2992520..fd726627bb1b81 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -480,9 +480,10 @@ def _init_weights(self, module): # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPT2Model): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass @@ -885,7 +886,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index be90f61e45bf1b..3bcb4a86581262 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -405,9 +405,10 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) # Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->GPTBigCode - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTBigCodeModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GPT_BIGCODE_START_DOCSTRING = r""" @@ -658,7 +659,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 3ad49554c0ac8f..494187a33aa452 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -384,9 +384,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTNeoModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GPT_NEO_START_DOCSTRING = r""" @@ -612,7 +613,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py 
b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 9391805a77b851..19560dc6c97531 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -78,9 +78,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTNeoXModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class GPTNeoXAttention(nn.Module): @@ -649,7 +650,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 98753edeb544f8..c1c5527a465531 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -66,9 +66,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTNeoXJapaneseModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class GPTNeoXJapaneseAttention(nn.Module): diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 6b5607f235b1a6..a51d4bdd094c7d 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -361,9 +361,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTJModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GPTJ_START_DOCSTRING = r""" @@ -675,7 +676,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py index 24917fcfdb075d..84d956c9f57e60 100644 --- a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py +++ b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py @@ -759,9 +759,10 @@ def _init_weights(self, module): module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, 
(GPTSanJapaneseAttention,)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right def _shift_right(self, input_ids): diff --git a/src/transformers/models/graphormer/modeling_graphormer.py b/src/transformers/models/graphormer/modeling_graphormer.py index 8247745a3bc3ef..68ed6d265e70ae 100755 --- a/src/transformers/models/graphormer/modeling_graphormer.py +++ b/src/transformers/models/graphormer/modeling_graphormer.py @@ -772,9 +772,10 @@ def _init_weights( module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GraphormerModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class GraphormerModel(GraphormerPreTrainedModel): diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index 59ff60ed765a51..d4199891f6c96c 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -805,9 +805,10 @@ def _init_weights(self, module): nn.init.normal_(module.fc1.weight, std=fc_std) nn.init.normal_(module.fc2.weight, std=in_proj_std) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (GroupViTTextEncoder, GroupViTVisionEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GROUPVIT_START_DOCSTRING = r""" @@ -1038,7 +1039,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 1a7bde45efc128..9acb52c2aedb08 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -353,7 +353,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -738,7 +738,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -828,7 +828,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -895,9 +895,10 @@ def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, 
(HubertEncoder, HubertEncoderStableLayerNorm)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 316f36561308f0..d3f9c5da4d2d7a 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -978,9 +978,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, IdeficsModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LLAMA_INPUTS_DOCSTRING = r""" @@ -1339,7 +1340,7 @@ def vblock( ) use_cache = False - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( vblock, decoder_layer, hidden_states, diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index d4966a240d84eb..eb2b836169d663 100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -408,7 +408,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 54edcd30fc870d..f3ebc9324260b9 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -525,9 +525,10 @@ def _init_weights(self, module): # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ImageGPTModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None IMAGEGPT_START_DOCSTRING = r""" @@ -824,7 +825,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index e7b35174ca7e60..5b93a16d3e0252 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -924,9 +924,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (InformerDecoder, 
InformerEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None INFORMER_START_DOCSTRING = r""" @@ -1222,14 +1223,14 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), ) if conv_layer is not None: - output = torch.utils.checkpoint.checkpoint(conv_layer, layer_outputs[0]) + output = self.gradient_checkpointing_func(conv_layer, layer_outputs[0]) layer_outputs = (output,) + layer_outputs[1:] else: layer_outputs = encoder_layer( @@ -1446,7 +1447,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 082900a6652f80..7b02ee85020cfb 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -304,9 +304,10 @@ def _init_weights(self, module): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, InstructBlipEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None INSTRUCTBLIP_START_DOCSTRING = r""" @@ -469,7 +470,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -946,7 +947,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 884a2799728b47..82531ab7a455ee 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -494,7 +494,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -638,9 +638,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LayoutLMEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LAYOUTLM_START_DOCSTRING = r""" diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py 
b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index ef970edfdc9103..30ff103bea7daa 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -446,7 +446,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -514,9 +514,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LayoutLMv2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def my_convert_sync_batchnorm(module, process_group=None): diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 30ab0a5e8620c3..42162dcfb2e54f 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -668,7 +668,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index f0c22ed9502c26..1029a7950a2e20 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1155,9 +1155,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (LEDDecoder, LEDEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -1883,7 +1884,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -2150,7 +2151,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/levit/modeling_levit.py b/src/transformers/models/levit/modeling_levit.py index 0accc28391bde6..5acaaeba90048a 100644 --- a/src/transformers/models/levit/modeling_levit.py +++ b/src/transformers/models/levit/modeling_levit.py @@ -507,9 +507,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LevitModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = 
gradient_checkpointing_func is not None LEVIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 46fe2d3e9cd779..65c381fc50a935 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -521,7 +521,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layout_inputs, @@ -607,9 +607,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LiltEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LILT_START_DOCSTRING = r""" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index b67719ac327162..5664a581ffb771 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -705,9 +705,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LlamaModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LLAMA_INPUTS_DOCSTRING = r""" @@ -921,7 +922,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids ) else: diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 33bf9a6f92684c..3b77ad46aed3da 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1311,7 +1311,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -1439,9 +1439,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LongformerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LONGFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index c80d2349832c43..4c6ff76cc95d78 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1341,9 +1341,10 @@ def _init_weights(self, module): ) # Copied from 
transformers.models.t5.modeling_t5.T5PreTrainedModel._set_gradient_checkpointing with T5->LongT5 - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (LongT5Attention, LongT5Stack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5 def _shift_right(self, input_ids): diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index 6913ede09d1c7b..fde39d0999af64 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -795,7 +795,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), word_hidden_states, entity_hidden_states, @@ -920,9 +920,10 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LukeEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LUKE_START_DOCSTRING = r""" diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 6db8bbb5213b14..264aff5b4aac4b 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -552,9 +552,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (M2M100Decoder, M2M100Encoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None M2M_100_START_DOCSTRING = r""" @@ -827,7 +828,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1074,7 +1075,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 69de5b2e7d0e6f..a0ab7192718bdc 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -500,9 +500,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalP if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, 
gradient_checkpointing_func=None): if isinstance(module, (MarianDecoder, MarianEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -795,7 +796,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1045,7 +1046,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 530c66a0c80b36..9686b0a1d3051d 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -655,7 +655,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index e839b16f625777..9ec586a17bb34b 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -1871,7 +1871,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 87b91ed64b62d3..8502a6a368eacd 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -855,7 +855,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, @@ -1619,11 +1619,13 @@ def _init_weights(self, module: nn.Module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MaskFormerPixelLevelModule): - module.encoder.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.encoder.gradient_checkpointing = gradient_checkpointing_func is not None if isinstance(module, DetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 357ac9d4aaca36..fe9dbc91f801a0 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ 
-695,7 +695,7 @@ def custom_forward(*inputs): return custom_forward - layer_hidden_states, output_dimensions, layer_all_hidden_states = torch.utils.checkpoint.checkpoint( + layer_hidden_states, output_dimensions, layer_all_hidden_states = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask ) else: @@ -752,9 +752,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MaskFormerSwinEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel): diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index b53ad8848dd3c2..644c5d292b0eea 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -516,9 +516,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MBartDecoder, MBartDecoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -835,7 +836,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1094,7 +1095,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 5d0ad6e3410c8f..16d463dcb470b1 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -558,7 +558,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -728,9 +728,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MegatronBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/mgp_str/modeling_mgp_str.py b/src/transformers/models/mgp_str/modeling_mgp_str.py index 5d1f5bea7bfd35..1257b4df39c015 100644 --- a/src/transformers/models/mgp_str/modeling_mgp_str.py +++ 
b/src/transformers/models/mgp_str/modeling_mgp_str.py @@ -333,9 +333,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: MgpstrEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: MgpstrEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, MgpstrEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MGP_STR_START_DOCSTRING = r""" diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index d650d60b8a553e..1544ebeaaf8180 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -689,9 +689,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MistralModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MISTRAL_INPUTS_DOCSTRING = r""" @@ -926,7 +927,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index c3accb21e05e42..0653321df9c38b 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -633,7 +633,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, ) @@ -672,9 +672,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MobileViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MOBILEVIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index 5a0e08d7344dc7..5aca04266e464f 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -589,7 +589,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, ) @@ -629,9 +629,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, 
value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MobileViTV2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MOBILEVITV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index d760bec9854a8e..279b0bc903a5b2 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -294,9 +294,10 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False): + def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): if isinstance(module, MptModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @staticmethod def _convert_to_mpt_cache( @@ -531,7 +532,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, alibi, diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index d400fea6d23dda..672e2666533ded 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -773,7 +773,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -871,9 +871,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MraEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MRA_START_DOCSTRING = r""" diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 186db94dad7f20..2e2b68060dc923 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -845,9 +845,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MT5Attention, MT5Stack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index bcc6bc82a2f5f4..6bee6f35dc7d60 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ 
-475,9 +475,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MusicgenDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MUSICGEN_START_DOCSTRING = r""" @@ -1562,10 +1563,10 @@ def tie_weights(self): self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.text_encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.text_encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_audio_encoder(self): return self.audio_encoder diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 5c1ed05249ef5c..f44e067aac3102 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -563,9 +563,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MvpDecoder, MvpEncoder, MvpPrompt)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -956,7 +957,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1235,7 +1236,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/nat/modeling_nat.py b/src/transformers/models/nat/modeling_nat.py index ecc745b558dd71..4f7206a5e8ecf5 100644 --- a/src/transformers/models/nat/modeling_nat.py +++ b/src/transformers/models/nat/modeling_nat.py @@ -639,7 +639,7 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: NatEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: NatEncoder, gradient_checkpointing_func=None) -> None: pass diff --git a/src/transformers/models/nezha/modeling_nezha.py b/src/transformers/models/nezha/modeling_nezha.py index fa31e94f4d2e6b..5a94e43291cb7e 100644 --- a/src/transformers/models/nezha/modeling_nezha.py +++ b/src/transformers/models/nezha/modeling_nezha.py @@ -584,7 +584,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = 
self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -752,9 +752,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, NezhaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 3701bbecef2e73..6c42ffa95b2de0 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -874,9 +874,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (NllbMoeDecoder, NllbMoeEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None NLLB_MOE_START_DOCSTRING = r""" @@ -1160,7 +1161,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 51ee73ab72d317..3c5df5dedd2e46 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -377,7 +377,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -477,9 +477,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, NystromformerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None NYSTROMFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 5b6220f8816949..165684542d859d 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -2616,7 +2616,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor): for layer in self.layers: if self.use_checkpoint: - hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states) + hidden_states = self.gradient_checkpointing_func(layer, hidden_states) else: hidden_states = layer(hidden_states) return hidden_states diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 8f3f246524348d..c97d57fa236f4e 100644 --- 
a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -411,9 +411,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (OPTDecoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OPT_INPUTS_DOCSTRING = r""" @@ -699,7 +700,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, causal_attention_mask, diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index 451cc4a69126a5..5aee16cc81060a 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -584,9 +584,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Owlv2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OWLV2_START_DOCSTRING = r""" @@ -771,7 +772,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1378,6 +1379,7 @@ def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor): def objectness_predictor(self, image_features: torch.FloatTensor) -> torch.FloatTensor: """Predicts the probability that each image feature token is an object. + Args: image_features (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_dim)`)): Features extracted from the image. 
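
Editor's note: the same two-part pattern repeats across every model touched in these hunks. `_set_gradient_checkpointing` now receives a callable instead of a boolean, stores it on the target module as `gradient_checkpointing_func`, and derives the `gradient_checkpointing` flag from whether a callable was supplied; the encoder/decoder forward then invokes that stored callable where it previously hard-coded `torch.utils.checkpoint.checkpoint`. The sketch below only illustrates how those two pieces fit together; `ToyLayer`, `ToyEncoder`, and `ToyModel` are made-up names for this note, not classes from the diff.

    import torch
    import torch.utils.checkpoint
    from torch import nn

    class ToyLayer(nn.Module):
        def __init__(self, hidden_size=32):
            super().__init__()
            self.dense = nn.Linear(hidden_size, hidden_size)

        def forward(self, hidden_states):
            return torch.relu(self.dense(hidden_states))

    class ToyEncoder(nn.Module):
        def __init__(self, num_layers=2, hidden_size=32):
            super().__init__()
            self.layers = nn.ModuleList(ToyLayer(hidden_size) for _ in range(num_layers))
            # Both attributes are populated by _set_gradient_checkpointing below.
            self.gradient_checkpointing = False
            self.gradient_checkpointing_func = None

        def forward(self, hidden_states):
            for layer in self.layers:
                if self.gradient_checkpointing and self.training:
                    # The injected callable replaces the hard-coded
                    # torch.utils.checkpoint.checkpoint call seen in the old code.
                    hidden_states = self.gradient_checkpointing_func(layer, hidden_states)
                else:
                    hidden_states = layer(hidden_states)
            return hidden_states

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = ToyEncoder()

        def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
            # Mirrors the per-model hook in the diff: store the callable and
            # flip the boolean flag based on whether one was provided.
            if isinstance(module, ToyEncoder):
                module.gradient_checkpointing_func = gradient_checkpointing_func
                module.gradient_checkpointing = gradient_checkpointing_func is not None

    # Wiring it up: apply() visits every submodule, the isinstance check filters.
    model = ToyModel()
    model.apply(lambda m: model._set_gradient_checkpointing(m, torch.utils.checkpoint.checkpoint))
    model.train()
    output = model.encoder(torch.randn(1, 4, 32, requires_grad=True))
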
diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index 66cfb8092df5df..b5317ea1c1b86b 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -576,9 +576,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, OwlViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OWLVIT_START_DOCSTRING = r""" @@ -760,7 +761,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 55856f7b06b6be..705cf956f78482 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -500,9 +500,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (PegasusDecoder, PegasusEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PEGASUS_START_DOCSTRING = r""" @@ -810,7 +811,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1095,7 +1096,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index e87e9c7164ab44..5f588842923169 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -780,9 +780,10 @@ def _init_weights(self, module): elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (PegasusXDecoder, PegasusXEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PEGASUS_X_START_DOCSTRING = r""" @@ -1078,7 +1079,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, global_hidden_states, @@ -1339,7 +1340,7 @@ def custom_forward(*inputs): return custom_forward - 
layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index a0bc5726382336..c6092e158c93a5 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -467,9 +467,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, PersimmonModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PERSIMMON_INPUTS_DOCSTRING = r""" @@ -677,7 +678,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 58041820c1fb83..31cedc13359f95 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -350,7 +350,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -563,9 +563,10 @@ def __init__(self, config: Pix2StructConfig): # Initialize weights and apply final processing self.post_init() - def _set_gradient_checkpointing(self, module: Pix2StructVisionEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: Pix2StructVisionEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, Pix2StructVisionEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def get_input_embeddings(self): return self.embeddings.patch_projection @@ -1320,9 +1321,10 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] supports_gradient_checkpointing = True - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Pix2StructTextAttention, Pix2StructTextModel)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 3a880839236d43..a079b0bf0cf5be 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -517,9 +517,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def 
_set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (PLBartDecoder, PLBartEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PLBART_START_DOCSTRING = r""" @@ -814,7 +815,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1072,7 +1073,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py index 6acc8ec98e6939..209533e31990e2 100755 --- a/src/transformers/models/poolformer/modeling_poolformer.py +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -282,9 +282,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, PoolFormerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None POOLFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 5a67b8044b0999..acb43f824b7bf7 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -739,9 +739,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Pop2PianoAttention, Pop2PianoStack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 241a9efea36aaf..04d2b946eafcf5 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -557,9 +557,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (ProphetNetDecoder, ProphetNetEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1336,7 
+1337,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, extended_attention_mask, @@ -1577,7 +1578,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/pvt/modeling_pvt.py b/src/transformers/models/pvt/modeling_pvt.py index 2dd452ec1df153..356b7c14afa8e0 100755 --- a/src/transformers/models/pvt/modeling_pvt.py +++ b/src/transformers/models/pvt/modeling_pvt.py @@ -489,9 +489,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ) - def _set_gradient_checkpointing(self, module: PvtEncoder, value: bool = False): + def _set_gradient_checkpointing(self, module: PvtEncoder, gradient_checkpointing_func=None): if isinstance(module, PvtEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PVT_START_DOCSTRING = r""" diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index fead8fc0cf7f42..cf307fb35009f3 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -588,7 +588,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -757,9 +757,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, QDQBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None QDQBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index aa738d782b7b6d..8f7d0a65600296 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -593,7 +593,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/regnet/modeling_regnet.py b/src/transformers/models/regnet/modeling_regnet.py index 07ef29fd33320b..21050f07fda441 100644 --- a/src/transformers/models/regnet/modeling_regnet.py +++ b/src/transformers/models/regnet/modeling_regnet.py @@ -293,9 +293,10 @@ def _init_weights(self, module): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RegNetModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + 
module.gradient_checkpointing = gradient_checkpointing_func is not None REGNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 235bff89f8a354..6dd04ed4030c47 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -550,7 +550,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -673,9 +673,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RemBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None REMBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index f2d207c2189f27..e6b1d85b2a46e8 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -283,9 +283,10 @@ def _init_weights(self, module): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ResNetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None RESNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 6d4cc991d22ca0..d7ead17b45449c 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -517,7 +517,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -612,9 +612,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RobertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index da1cd6331bc314..4ae7a308f68e5d 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -519,7 +519,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), 
hidden_states, attention_mask, @@ -615,9 +615,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RobertaPreLayerNormEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROBERTA_PRELAYERNORM_START_DOCSTRING = r""" diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index a5b1b63050b1ef..d1b84ab31f631a 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -651,7 +651,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -796,9 +796,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RoCBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROC_BERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index b9c36a305ff1cd..e860ff34eb5295 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -585,7 +585,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -715,9 +715,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RoFormerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index db41bd3c9538c0..bbe21949f0e67a 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -466,9 +466,10 @@ def _init_weights(self, module): module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RwkvModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass @@ -684,7 +685,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states, 
state, attentions = torch.utils.checkpoint.checkpoint( + hidden_states, state, attentions = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, state ) else: diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index abf5544a5b4de6..f5cd7cf0a45bd5 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -1049,7 +1049,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, ) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 34f9c84235cc08..44cbcec5267af3 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -367,7 +367,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -680,7 +680,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -756,9 +756,10 @@ def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (SEWEncoder, SEWFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 661a8c03b1a5d9..74374e1a4eb98e 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -460,7 +460,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -1141,7 +1141,7 @@ def custom_forward(*inputs): return custom_forward - output_states = torch.utils.checkpoint.checkpoint( + output_states = self.gradient_checkpointing_func( create_custom_forward(layer_module), next_kv, attention_mask, @@ -1322,9 +1322,10 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, SEWDTransformerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SEWD_START_DOCSTRING = r""" diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index e80c26e2698d73..ec255fab9bc766 
100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -249,10 +249,10 @@ def __init__( f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_encoder(self): return self.encoder diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 31c9b6cfe93552..acdcc2f902cb90 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -559,9 +559,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Speech2TextDecoder, Speech2TextEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ @@ -824,7 +825,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1073,7 +1074,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index f9b5dec4209273..9d863ba3e2f2fe 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -437,9 +437,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Speech2Text2Decoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SPEECH_TO_TEXT_2_START_DOCSTRING = r""" @@ -677,7 +678,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git 
a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 9b8ab3d3805a05..ef374bbb32e789 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -527,7 +527,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -1281,9 +1281,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (SpeechT5Encoder, SpeechT5Decoder, SpeechT5FeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class SpeechT5Encoder(SpeechT5PreTrainedModel): @@ -1393,7 +1394,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1723,7 +1724,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index f72ffb10111bc7..f1ab50179dea6a 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -466,7 +466,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -544,9 +544,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, SplinterEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SPLINTER_START_DOCSTRING = r""" diff --git a/src/transformers/models/swiftformer/modeling_swiftformer.py b/src/transformers/models/swiftformer/modeling_swiftformer.py index ff72f87506d36a..4170ce153bbfeb 100644 --- a/src/transformers/models/swiftformer/modeling_swiftformer.py +++ b/src/transformers/models/swiftformer/modeling_swiftformer.py @@ -442,9 +442,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No nn.init.constant_(module.bias, 0) nn.init.constant_(module.weight, 1.0) - def _set_gradient_checkpointing(self, module: SwiftFormerEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: SwiftFormerEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, SwiftFormerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None 
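
Editor's note: these hunks only cover the per-model side of the change; the caller that actually builds and injects the checkpointing callable lives in src/transformers/modeling_utils.py (the first file in this patch, not reproduced in this excerpt). As a rough illustration of how such a caller could look, the helpers below build the function once and broadcast it to every submodule via `apply`. `enable_gradient_checkpointing` and `disable_gradient_checkpointing` are hypothetical names, and the `use_reentrant` handling is an assumption for the sketch, not something taken from this diff.

    import functools

    import torch
    import torch.utils.checkpoint

    def enable_gradient_checkpointing(model, use_reentrant=False):
        # Pick the checkpoint implementation once, then hand the same callable
        # to every submodule through the per-model _set_gradient_checkpointing hook.
        checkpoint_func = functools.partial(
            torch.utils.checkpoint.checkpoint, use_reentrant=use_reentrant
        )
        model.apply(
            functools.partial(
                model._set_gradient_checkpointing,
                gradient_checkpointing_func=checkpoint_func,
            )
        )

    def disable_gradient_checkpointing(model):
        # Passing no callable makes the hook reset gradient_checkpointing to False.
        model.apply(model._set_gradient_checkpointing)

Threading a callable through instead of a boolean is what lets the checkpointing variant (for example reentrant vs. non-reentrant) be chosen in one place rather than being fixed inside every modeling file.
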
SWIFTFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 45a7aa718cf026..228d962dea1dc5 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -832,7 +832,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: @@ -901,9 +901,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, SwinEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIN_START_DOCSTRING = r""" diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py index 02ec39edb0fe14..5d53561442457f 100644 --- a/src/transformers/models/swin/modeling_tf_swin.py +++ b/src/transformers/models/swin/modeling_tf_swin.py @@ -951,11 +951,6 @@ class TFSwinPreTrainedModel(TFPreTrainedModel): config_class = SwinConfig base_model_prefix = "swin" main_input_name = "pixel_values" - supports_gradient_checkpointing = True - - def _set_gradient_checkpointing(self, module, value=False) -> None: - if isinstance(module, TFSwinEncoder): - module.gradient_checkpointing = value SWIN_START_DOCSTRING = r""" diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index a8a17bdf584b00..db8ff6a652eb63 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -753,7 +753,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(stage_module), hidden_states, input_dimensions, layer_head_mask ) else: @@ -802,9 +802,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Swin2SREncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIN2SR_START_DOCSTRING = r""" diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index a4224e16df3c25..fda0e080d0d8d0 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -913,7 +913,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: @@ -983,9 +983,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if 
isinstance(module, Swinv2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWINV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 0a402ea2d6af87..ed0d59abb8bc5d 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -865,9 +865,10 @@ def _init_weights(self, module): module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (SwitchTransformersAttention, SwitchTransformersStack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 0e7237ea36b647..603c6a4730e8f3 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -873,9 +873,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (T5Attention, T5Stack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index 8f59bd4b6e1785..fb42673ae5c55f 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -831,9 +831,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TableTransformerDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TABLE_TRANSFORMER_START_DOCSTRING = r""" @@ -1150,7 +1151,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 
cdaa4b3e2725f7..e6ce415899faec 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -653,7 +653,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -778,9 +778,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TapasEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TAPAS_START_DOCSTRING = r""" diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 2caca5bd105131..c550f89e9504c0 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -663,9 +663,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (TimeSeriesTransformerDecoder, TimeSeriesTransformerEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TIME_SERIES_TRANSFORMER_START_DOCSTRING = r""" @@ -953,7 +954,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1171,7 +1172,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index 676bcf7a5e27a0..df7dd1c953f5bf 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -446,7 +446,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, ) @@ -494,9 +494,10 @@ def _init_weights(self, module): nn.init.trunc_normal_(module.position_embeddings, std=self.config.initializer_range) module.patch_embeddings.apply(self._init_weights) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TimesformerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TIMESFORMER_START_DOCSTRING = r""" diff --git 
a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index c0541814be466e..6971b4dfb21a8f 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -454,9 +454,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TrOCRDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TROCR_START_DOCSTRING = r""" @@ -709,7 +710,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/tvlt/modeling_tvlt.py b/src/transformers/models/tvlt/modeling_tvlt.py index 464c3e76a11f94..8852083c4694cd 100644 --- a/src/transformers/models/tvlt/modeling_tvlt.py +++ b/src/transformers/models/tvlt/modeling_tvlt.py @@ -567,7 +567,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -616,9 +616,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TvltEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TVLT_START_DOCSTRING = r""" @@ -884,7 +885,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, None, diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index ffafd158114074..e6e9aaa26a380b 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -556,9 +556,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (UMT5Attention, UMT5Stack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index c475ab7f80f87d..8f667d3d564c48 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -391,7 +391,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states 
= torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -774,7 +774,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -864,7 +864,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -1039,9 +1039,10 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (UniSpeechEncoder, UniSpeechEncoderStableLayerNorm, UniSpeechFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None UNISPEECH_START_DOCSTRING = r""" diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 3fcc9549bbdcd6..5584929ab11c00 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -405,7 +405,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -788,7 +788,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -878,7 +878,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -1053,9 +1053,10 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (UniSpeechSatEncoder, UniSpeechSatEncoderStableLayerNorm, UniSpeechSatFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None UNISPEECH_SAT_START_DOCSTRING = r""" diff --git a/src/transformers/models/upernet/modeling_upernet.py b/src/transformers/models/upernet/modeling_upernet.py index b56b508d14ae63..04b8c94e1351bd 100644 --- a/src/transformers/models/upernet/modeling_upernet.py +++ b/src/transformers/models/upernet/modeling_upernet.py @@ -315,9 +315,10 @@ def init_weights(self): if self.auxiliary_head is not None: self.auxiliary_head.init_weights() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BackboneMixin): - 
module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None UPERNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 07c32d14929037..9657a747fbc365 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -441,7 +441,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -489,9 +489,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, VideoMAEEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIDEOMAE_START_DOCSTRING = r""" @@ -733,7 +734,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, None, diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index a36d58bd235bb5..9c8fee6c79f361 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -538,7 +538,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -591,9 +591,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ViltEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VILT_START_DOCSTRING = r""" diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index d3e464cbfffa08..84275cc33a767b 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -225,10 +225,10 @@ def __init__( f"The encoder {self.encoder} should not have a LM Head. 
Please use a model without LM Head" ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_encoder(self): return self.encoder diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 81ad1068483a80..c2eaa90b48640e 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -425,7 +425,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -547,9 +547,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, VisualBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 8fdacdddf04cba..050d02ee2990c1 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -404,7 +404,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -467,9 +467,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ).to(module.cls_token.dtype) - def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ViTEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py index 008f6b3c9db536..fa4b3471e37579 100644 --- a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py @@ -422,7 +422,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -486,9 +486,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ).to(module.cls_token.dtype) - def 
_set_gradient_checkpointing(self, module: ViTHybridEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ViTHybridEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ViTHybridEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index ef0c7c9f36869e..b468075d08ff3e 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -543,7 +543,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -591,9 +591,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ViTMAEEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_MAE_START_DOCSTRING = r""" @@ -800,7 +801,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, None, diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 46639e7d622cb7..87779dd3ae944d 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -394,7 +394,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -444,9 +444,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: ViTMSNEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ViTMSNEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ViTMSNEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_MSN_START_DOCSTRING = r""" diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py index e89fdbd7a33631..fd9f26923444cc 100644 --- a/src/transformers/models/vitdet/modeling_vitdet.py +++ b/src/transformers/models/vitdet/modeling_vitdet.py @@ -572,7 +572,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -666,9 +666,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.norm3.weight.data.zero_() module.norm3.bias.data.zero_() - def 
_set_gradient_checkpointing(self, module: VitDetEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: VitDetEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, VitDetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VITDET_START_DOCSTRING = r""" diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py index b23bdd21d56b85..18b8b80b328c31 100644 --- a/src/transformers/models/vitmatte/modeling_vitmatte.py +++ b/src/transformers/models/vitmatte/modeling_vitmatte.py @@ -86,9 +86,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BackboneMixin): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class VitMatteBasicConv3x3(nn.Module): diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index 49b9a1f1ae1551..7b7899ee287f31 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -1174,7 +1174,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, padding_mask, @@ -1296,9 +1296,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (VitsTextEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VITS_START_DOCSTRING = r""" diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index fd35668572a776..5e07b1544b2b2c 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -345,7 +345,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -414,9 +414,10 @@ def _init_weights(self, module): elif isinstance(module, nn.Parameter): module.data.normal_(mean=0.0, std=self.config.initializer_range) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, VivitEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIVIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index a6e02a0476f15d..c02e23660b5666 100755 --- 
a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -458,7 +458,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -810,7 +810,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -899,7 +899,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -1173,9 +1173,10 @@ def _get_feature_vector_attention_mask( attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_adapters(self): if self.config.adapter_attn_dim is None: diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index f162c514297067..edcdcf4a22ac47 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -525,7 +525,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -918,7 +918,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -1178,9 +1178,10 @@ def _get_feature_vector_attention_mask( attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None WAV2VEC2_CONFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 5013837cbdcefd..182482dfd83ad6 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -361,7 +361,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -720,7 +720,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = 
self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -811,7 +811,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -1052,9 +1052,10 @@ def _get_feature_vector_attention_mask( attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (WavLMEncoder, WavLMEncoderStableLayerNorm, WavLMFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None WAVLM_START_DOCSTRING = r""" diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 8962324471cafc..f4b8fb4852a98f 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -685,9 +685,10 @@ def _init_weights(self, module): embed_positions = module.embed_positions.weight embed_positions.copy_(sinusoids(*embed_positions.shape)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (WhisperDecoder, WhisperEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ @@ -949,7 +950,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, None, @@ -1182,7 +1183,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index da7eddff8df838..025533ab41781a 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -534,9 +534,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (XCLIPEncoder, XCLIPVisionEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None X_CLIP_START_DOCSTRING = r""" @@ -710,7 +711,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -957,7 +958,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( 
create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 0c769dbbb5f324..16f0402abf98ce 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -503,9 +503,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, XGLMModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( @@ -682,7 +683,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py index cde05cfe8a8a68..e599cc3cede777 100644 --- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -570,9 +570,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (XLMProphetNetDecoder, XLMProphetNetEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1356,7 +1357,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, extended_attention_mask, @@ -1600,7 +1601,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index da454b1e3331f9..b195ee43723e51 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -518,7 +518,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -614,9 +614,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, XLMRobertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = 
gradient_checkpointing_func is not None XLM_ROBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 26e0361abdb523..0e3ed4eeb986f1 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -506,7 +506,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 28fddc2fdbd6b5..2eb0ba83d7260c 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -580,7 +580,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, lang_ids, @@ -680,9 +680,10 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) # Copied from transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel._set_gradient_checkpointing with Roberta->Xmod - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, XmodEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def set_default_language(self, language: str): """ diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index e3cb02ceae6ec0..0884529777c205 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -499,7 +499,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -551,9 +551,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: YolosEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: YolosEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, YolosEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None YOLOS_START_DOCSTRING = r""" diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index 5edd7f8835422a..0159d7fb76d32f 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -568,7 +568,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -668,9 +668,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + 
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, YosoEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None YOSO_START_DOCSTRING = r""" diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 02fcb7d2f511e1..ee583cec354899 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -550,7 +550,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -679,9 +679,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, {{cookiecutter.camelcase_modelname}}Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" @@ -2024,9 +2025,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}Encoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" @@ -2319,7 +2321,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -2558,7 +2560,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 34f5bae3746f03..74ea66fd6fa01d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -569,6 +569,13 @@ def test_training_gradient_checkpointing(self): loss = model(**inputs).loss loss.backward() + model.gradient_checkpointing_disable() + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + def test_attention_outputs(self): if not self.has_attentions: self.skipTest(reason="Model does not output attentions") From 
449b4a4c3cbd3d5edeab0dd11d34570fddcb7073 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 23 Oct 2023 13:32:57 +0000 Subject: [PATCH 02/12] fix --- .../models/seamless_m4t/modeling_seamless_m4t.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 7df6fcd98907ca..6b0e7a7ff079b2 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -900,7 +900,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, relative_position_embeddings, @@ -1547,9 +1547,10 @@ def _init_weights(self, module): k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-k, b=k) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (SeamlessM4TDecoder, SeamlessM4TEncoder, SeamlessM4TSpeechEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): kernel_size, stride = self.config.adaptor_kernel_size, self.config.adaptor_stride @@ -1864,7 +1865,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -2139,7 +2140,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, From 6b4ab9f0c2a43e3a729982824bb00e481815a063 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 23 Oct 2023 14:37:13 +0000 Subject: [PATCH 03/12] remove `create_custom_forward` --- .../models/align/modeling_align.py | 11 +++------- .../models/altclip/modeling_altclip.py | 21 +++++-------------- .../modeling_audio_spectrogram_transformer.py | 10 ++------- .../models/autoformer/modeling_autoformer.py | 12 +++-------- src/transformers/models/bark/modeling_bark.py | 12 +++-------- src/transformers/models/bart/modeling_bart.py | 12 +++-------- src/transformers/models/beit/modeling_beit.py | 10 ++------- src/transformers/models/bert/modeling_bert.py | 11 +++------- .../modeling_bert_generation.py | 11 +++------- .../models/big_bird/modeling_big_bird.py | 11 +++------- .../modeling_bigbird_pegasus.py | 12 +++-------- .../models/biogpt/modeling_biogpt.py | 12 +++-------- .../models/blenderbot/modeling_blenderbot.py | 12 +++-------- .../modeling_blenderbot_small.py | 12 +++-------- src/transformers/models/blip/modeling_blip.py | 10 ++------- .../models/blip/modeling_blip_text.py | 11 +++------- .../models/blip_2/modeling_blip_2.py | 12 +++-------- .../models/bloom/modeling_bloom.py | 12 +++-------- .../bridgetower/modeling_bridgetower.py | 11 +++------- src/transformers/models/bros/modeling_bros.py | 10 ++------- .../models/camembert/modeling_camembert.py | 11 +++------- .../models/canine/modeling_canine.py | 10 ++------- 
.../chinese_clip/modeling_chinese_clip.py | 13 ++++-------- src/transformers/models/clap/modeling_clap.py | 20 ++++-------------- src/transformers/models/clip/modeling_clip.py | 10 ++------- .../models/clipseg/modeling_clipseg.py | 10 ++------- .../models/codegen/modeling_codegen.py | 12 +++-------- .../modeling_conditional_detr.py | 9 +------- .../models/convbert/modeling_convbert.py | 3 ++- .../data2vec/modeling_data2vec_audio.py | 11 ++-------- .../models/data2vec/modeling_data2vec_text.py | 11 +++------- .../data2vec/modeling_data2vec_vision.py | 10 ++------- .../models/deberta/modeling_deberta.py | 10 ++------- .../models/deberta_v2/modeling_deberta_v2.py | 10 ++------- .../modeling_decision_transformer.py | 12 +++-------- .../modeling_deformable_detr.py | 8 +------ src/transformers/models/deit/modeling_deit.py | 10 ++------- .../models/deprecated/mctct/modeling_mctct.py | 10 ++------- .../open_llama/modeling_open_llama.py | 12 +++-------- .../modeling_trajectory_transformer.py | 9 +------- src/transformers/models/deta/modeling_deta.py | 8 +------ src/transformers/models/detr/modeling_detr.py | 9 +------- .../models/dinov2/modeling_dinov2.py | 10 ++------- .../models/distilbert/modeling_distilbert.py | 10 ++------- .../models/donut/modeling_donut_swin.py | 9 +------- src/transformers/models/dpt/modeling_dpt.py | 10 ++------- .../models/electra/modeling_electra.py | 11 +++------- .../models/ernie/modeling_ernie.py | 11 +++------- src/transformers/models/esm/modeling_esm.py | 11 +++------- .../models/falcon/modeling_falcon.py | 13 ++++-------- .../models/flava/modeling_flava.py | 10 ++------- src/transformers/models/fnet/modeling_fnet.py | 9 +------- .../models/focalnet/modeling_focalnet.py | 9 +------- src/transformers/models/git/modeling_git.py | 21 +++++-------------- src/transformers/models/gpt2/modeling_gpt2.py | 12 +++-------- .../gpt_bigcode/modeling_gpt_bigcode.py | 12 +++-------- .../models/gpt_neo/modeling_gpt_neo.py | 12 +++-------- .../models/gpt_neox/modeling_gpt_neox.py | 13 ++++-------- src/transformers/models/gptj/modeling_gptj.py | 12 +++-------- .../models/groupvit/modeling_groupvit.py | 10 ++------- .../models/hubert/modeling_hubert.py | 13 +++--------- src/transformers/models/idefics/vision.py | 10 ++------- .../models/imagegpt/modeling_imagegpt.py | 12 +++-------- .../models/informer/modeling_informer.py | 12 +++-------- .../instructblip/modeling_instructblip.py | 12 +++-------- .../models/layoutlm/modeling_layoutlm.py | 11 +++------- .../models/layoutlmv2/modeling_layoutlmv2.py | 3 ++- .../models/layoutlmv3/modeling_layoutlmv3.py | 13 +----------- src/transformers/models/led/modeling_led.py | 13 ++++-------- src/transformers/models/lilt/modeling_lilt.py | 10 ++------- .../models/llama/modeling_llama.py | 17 +++++++-------- .../models/longformer/modeling_longformer.py | 11 +++------- .../models/longt5/modeling_longt5.py | 11 +++------- src/transformers/models/luke/modeling_luke.py | 10 ++------- .../models/m2m_100/modeling_m2m_100.py | 12 +++-------- .../models/marian/modeling_marian.py | 12 +++-------- .../models/markuplm/modeling_markuplm.py | 11 +++------- .../mask2former/modeling_mask2former.py | 10 ++------- .../models/maskformer/modeling_maskformer.py | 10 ++------- .../maskformer/modeling_maskformer_swin.py | 12 ++++------- .../models/mbart/modeling_mbart.py | 12 +++-------- .../megatron_bert/modeling_megatron_bert.py | 11 +++------- .../models/mistral/modeling_mistral.py | 14 +++++-------- .../models/mobilevit/modeling_mobilevit.py | 9 +------- 
.../mobilevitv2/modeling_mobilevitv2.py | 9 +------- src/transformers/models/mpt/modeling_mpt.py | 12 +++-------- src/transformers/models/mra/modeling_mra.py | 9 +------- src/transformers/models/mt5/modeling_mt5.py | 11 +++------- .../models/musicgen/modeling_musicgen.py | 12 +++-------- src/transformers/models/mvp/modeling_mvp.py | 12 +++-------- .../models/nezha/modeling_nezha.py | 11 +++------- .../models/nllb_moe/modeling_nllb_moe.py | 12 +++-------- .../nystromformer/modeling_nystromformer.py | 10 ++------- src/transformers/models/opt/modeling_opt.py | 12 +++-------- .../models/owlv2/modeling_owlv2.py | 10 ++------- .../models/owlvit/modeling_owlvit.py | 10 ++------- .../models/pegasus/modeling_pegasus.py | 12 +++-------- .../models/pegasus_x/modeling_pegasus_x.py | 12 +++-------- .../models/persimmon/modeling_persimmon.py | 12 +++-------- .../models/pix2struct/modeling_pix2struct.py | 12 +++-------- .../models/plbart/modeling_plbart.py | 12 +++-------- .../models/pop2piano/modeling_pop2piano.py | 11 +++------- .../models/prophetnet/modeling_prophetnet.py | 12 +++-------- .../models/qdqbert/modeling_qdqbert.py | 11 +++------- .../models/realm/modeling_realm.py | 11 +++------- .../models/rembert/modeling_rembert.py | 11 +++------- .../models/roberta/modeling_roberta.py | 11 +++------- .../modeling_roberta_prelayernorm.py | 11 +++------- .../models/roc_bert/modeling_roc_bert.py | 11 +++------- .../models/roformer/modeling_roformer.py | 11 +++------- src/transformers/models/rwkv/modeling_rwkv.py | 10 +-------- src/transformers/models/sam/modeling_sam.py | 9 +------- .../seamless_m4t/modeling_seamless_m4t.py | 8 +------ src/transformers/models/sew/modeling_sew.py | 11 ++-------- .../models/sew_d/modeling_sew_d.py | 19 +++-------------- .../speech_to_text/modeling_speech_to_text.py | 12 +++-------- .../modeling_speech_to_text_2.py | 10 +-------- .../models/speecht5/modeling_speecht5.py | 13 +++--------- .../models/splinter/modeling_splinter.py | 11 +++------- src/transformers/models/swin/modeling_swin.py | 9 +------- .../models/swin2sr/modeling_swin2sr.py | 9 +------- .../models/swinv2/modeling_swinv2.py | 9 +------- .../modeling_switch_transformers.py | 11 +++------- src/transformers/models/t5/modeling_t5.py | 11 +++------- .../modeling_table_transformer.py | 9 +------- .../models/tapas/modeling_tapas.py | 11 +++------- .../modeling_time_series_transformer.py | 12 +++-------- .../timesformer/modeling_timesformer.py | 10 ++------- .../models/trocr/modeling_trocr.py | 12 +++-------- src/transformers/models/tvlt/modeling_tvlt.py | 5 +++-- src/transformers/models/umt5/modeling_umt5.py | 11 +++------- .../models/unispeech/modeling_unispeech.py | 13 +++--------- .../unispeech_sat/modeling_unispeech_sat.py | 13 +++--------- .../models/videomae/modeling_videomae.py | 12 +++-------- src/transformers/models/vilt/modeling_vilt.py | 3 ++- .../visual_bert/modeling_visual_bert.py | 10 ++------- src/transformers/models/vit/modeling_vit.py | 10 ++------- .../models/vit_hybrid/modeling_vit_hybrid.py | 10 ++------- .../models/vit_mae/modeling_vit_mae.py | 12 +++-------- .../models/vit_msn/modeling_vit_msn.py | 10 ++------- .../models/vitdet/modeling_vitdet.py | 10 ++------- src/transformers/models/vits/modeling_vits.py | 10 ++------- .../models/vivit/modeling_vivit.py | 10 ++------- .../models/wav2vec2/modeling_wav2vec2.py | 13 +++--------- .../modeling_wav2vec2_conformer.py | 11 ++-------- .../models/wavlm/modeling_wavlm.py | 13 +++--------- .../models/whisper/modeling_whisper.py | 12 +++-------- 
.../models/x_clip/modeling_x_clip.py | 12 +++-------- src/transformers/models/xglm/modeling_xglm.py | 11 +++------- .../xlm_prophetnet/modeling_xlm_prophetnet.py | 12 +++-------- .../xlm_roberta/modeling_xlm_roberta.py | 11 +++------- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 11 +++------- src/transformers/models/xmod/modeling_xmod.py | 10 +++------ .../models/yolos/modeling_yolos.py | 9 ++------ src/transformers/models/yoso/modeling_yoso.py | 10 ++------- ...ng_{{cookiecutter.lowercase_modelname}}.py | 14 +++++-------- 156 files changed, 406 insertions(+), 1317 deletions(-) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index bad7db0150c85b..e132fae5c660e6 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -1095,20 +1095,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 56b9657aecb700..71e650adba1bff 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -646,20 +646,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -960,18 +955,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index 2a895dc073ba71..1c79f3cfd78b21 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -336,17 +336,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - 
return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 278811d23d9a34..520a9ddbc13124 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -1208,18 +1208,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1435,7 +1429,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 8ffb22fd3e5d9f..11c53ccbdb2192 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -638,20 +638,14 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, None, attention_mask, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 2af67b87b7394c..70013fba27dfff 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -855,18 +855,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1120,7 +1114,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index d30eff63f54131..860de96323be6a 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -510,17 +510,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if 
self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: relative_position_bias = ( diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 993160d6998a83..b251c9c9b55916 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -593,20 +593,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index c811f2d19d3187..97fb89e95413d6 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -401,20 +401,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 6677e658b8dd50..890eb8c6875f3e 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1617,15 +1617,8 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, @@ -1635,6 +1628,8 @@ def custom_forward(*inputs): from_mask, to_mask, blocked_encoder_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 98d8ae83179f5d..ad0640b04451f2 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1944,15 +1944,8 @@ def forward( layer_outputs = (None, None) else: if 
self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1961,6 +1954,7 @@ def custom_forward(*inputs): to_mask, blocked_encoder_mask, blocked_encoder_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2299,7 +2293,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 6597d2ea04e659..7dc72aa6368ecd 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -591,20 +591,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index caaf59d289a293..4a3248e5d44356 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -778,18 +778,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1042,7 +1036,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index d72ee4ceb5589b..ef9d0e9643f5d6 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -776,18 +776,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None 
else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1039,7 +1033,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 59f2590d04ee00..229afec0a81f8f 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -623,17 +623,11 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 317eea1e1b6e4a..a9decd052d375d 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -422,20 +422,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index bcb6f4f7b6c6f4..8339016efcb91c 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -474,17 +474,11 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -953,7 +947,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 688415ac11218b..83998421e131b0 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -762,21 +762,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, 
output_attentions=output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, alibi, causal_mask, layer_past, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index d64532170bbffa..ea4c3cc285badd 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -804,20 +804,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index 603dc2a52b8fdb..60e753c95f8de3 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -651,21 +651,15 @@ def forward( "`use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, bbox_pos_emb, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 44764f900abb9a..d5d9f0ae488f20 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -524,20 +524,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 9625e97ea28b7b..198e3376731adc 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -795,18 +795,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, 
attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 1f4a42732d7deb..c96521493fd5aa 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -910,20 +910,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -1027,7 +1022,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, ) else: diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index ccee38322c0b84..7c6c9618c4536e 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -939,15 +939,8 @@ def forward( input_dimensions = self.input_resolutions[i] if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( @@ -1595,20 +1588,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index e179244a1c326c..9e179753157baf 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -640,18 +640,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/clipseg/modeling_clipseg.py 
b/src/transformers/models/clipseg/modeling_clipseg.py index 385737bafe3368..0bded11f9bc1da 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -649,18 +649,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 464eeebc9ba0d9..0a01e05044e47e 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -541,21 +541,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, None, attention_mask, position_ids, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 2a4812eaf04897..d663e080df9323 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -1517,15 +1517,8 @@ def forward( # apply transformation query_sine_embed = query_sine_embed_before_transformation * pos_transformation if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, combined_attention_mask, object_queries, diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index 927c026df777a8..c040715c363092 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -641,12 +641,13 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 6d8bb5c2058c3a..71fd5a705990ba 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -293,15 +293,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def 
create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -601,7 +594,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 66588647f61bb5..ba5c6b97a965d9 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -510,20 +510,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index e7fd98091f974c..6c5c39e4957c52 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -522,17 +522,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: relative_position_bias = ( diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 06a33a7dd85c4f..a7816bae558bec 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -457,20 +457,14 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, next_kv, attention_mask, query_states, relative_pos, rel_embeddings, + output_attentions, ) else: hidden_states = layer_module( diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 2172f5d22eefe2..e536d376c59107 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -501,20 +501,14 @@ def forward( all_hidden_states = all_hidden_states + (output_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return 
custom_forward - output_states = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, next_kv, attention_mask, query_states, relative_pos, rel_embeddings, + output_attentions, ) else: output_states = layer_module( diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 3865fe523f7168..8146436cdc5186 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -632,22 +632,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 7e04d2a1c76069..fb8ed41ce7125b 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1383,14 +1383,8 @@ def forward( if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, encoder_hidden_states, encoder_attention_mask, diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index ff95a458ad77ef..4cd8785ce535e0 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -357,17 +357,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index e38b89a0a4441c..779b409470d920 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -617,18 +617,12 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, 
(head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index f021714be25060..5ab949b11ce3b4 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -666,20 +666,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, position_ids, None, + output_attentions, + None, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py index 13a26b6c05d584..8081a96430bcea 100644 --- a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py @@ -551,15 +551,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, layer_past, use_cache, diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index 2c5890e0a35733..ff24ed74856a24 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -1275,14 +1275,8 @@ def forward( if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, encoder_hidden_states, encoder_attention_mask, diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 4200e6556d507b..cc370a5a0c7c9e 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -1252,15 +1252,8 @@ def forward( continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 656a3022c96f4a..1fd39703bce305 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -447,17 +447,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is 
not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index de3c125abbacdd..db48ac56fee310 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -358,18 +358,12 @@ def forward( all_hidden_states = all_hidden_states + (hidden_state,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_state, attn_mask, head_mask[i], + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 1a1e215f9a6d0d..a789b7ef57ba46 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -749,15 +749,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index b13ca04626cf57..513892740ed7c6 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -528,17 +528,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index a7d943450a864b..eee30624719ecf 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -571,20 +571,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - 
create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index b178ca354495a4..d88563e778c790 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -506,20 +506,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 3115a1357ea6ea..21e480c8212b93 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -605,20 +605,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 29873a39457fb5..7f8e7db562d77f 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1148,21 +1148,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, alibi, causal_mask, position_ids, head_mask[i], + layer_past, + use_cache, + output_attentions, padding_mask, ) else: diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 9b5faaeb15f643..61431463215194 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -663,18 +663,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: 
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 299b607b6b8adf..f9ec022845f065 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -292,14 +292,7 @@ def forward(self, hidden_states, output_hidden_states=False, return_dict=True): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - layer_outputs = self.gradient_checkpointing_func(create_custom_forward(layer_module), hidden_states) + layer_outputs = self.gradient_checkpointing_func(layer_module.forward, hidden_states) else: layer_outputs = layer_module(hidden_states) diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index 0e33dc4f66f438..5ff1c99b94f3da 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -586,15 +586,8 @@ def forward( for i, stage_module in enumerate(self.stages): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - stage_outputs = self.gradient_checkpointing_func( - create_custom_forward(stage_module), + stage_module.forward, hidden_states, input_dimensions, ) diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index bcbee566fa247b..0e44931eb99ea1 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -452,18 +452,13 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -879,18 +874,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index fd726627bb1b81..ee84cb1bc88fcf 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -878,22 +878,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - outputs = 
self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 3bcb4a86581262..7d4e77a4674f6a 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -651,22 +651,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 494187a33aa452..6ede0829cd03a5 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -605,20 +605,14 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, None, attention_mask, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 19560dc6c97531..860552cde48527 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -642,20 +642,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for layer_past - return module(*inputs, use_cache, None, output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, position_ids, head_mask[i], + use_cache, + None, + output_attentions, ) else: outputs = layer( diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index a51d4bdd094c7d..c0302b6c21a044 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -668,21 +668,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, None, attention_mask, position_ids, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( 
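[Note on the pattern applied throughout this patch: each call site stops wrapping the layer in a create_custom_forward closure and instead hands the bound forward method, plus every argument the closure used to capture (flags such as output_attentions and use_cache included), directly to the checkpointing function. The following is a minimal sketch of the two call styles, not code from the library: ToyLayer and its signature are invented for illustration, and torch.utils.checkpoint.checkpoint stands in for self.gradient_checkpointing_func, which may be configured differently; use_reentrant=False is passed only to keep the sketch warning-free on recent PyTorch.]

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyLayer(nn.Module):
    # Illustrative layer with a transformers-style signature:
    # tensor inputs first, then non-tensor flags, returning a tuple.
    def __init__(self, hidden_size=8):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        hidden_states = self.linear(hidden_states)
        return (hidden_states, None) if output_attentions else (hidden_states,)


layer = ToyLayer()
hidden_states = torch.randn(2, 4, 8, requires_grad=True)


# Old style: a closure captures the non-tensor arguments so that only tensors
# travel through the checkpoint call.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs, output_attentions=False)

    return custom_forward


old_out = checkpoint(create_custom_forward(layer), hidden_states, None, use_reentrant=False)

# New style (this patch): pass the bound method and every argument positionally,
# flags included, with no wrapper closure.
new_out = checkpoint(layer.forward, hidden_states, None, False, use_reentrant=False)

# Both styles run the same computation, so the checkpointed outputs match.
assert torch.allclose(old_out[0], new_out[0])

[The flags are placed in their positional slots rather than passed as keywords, presumably because the reentrant checkpoint implementation only forwards positional arguments to the wrapped function, which is why the closures were introduced originally.]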
diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index d4199891f6c96c..e4aeaf70996fec 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -1032,18 +1032,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 9acb52c2aedb08..94b2a205d8ca1c 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -346,15 +346,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -739,7 +732,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) @@ -829,7 +822,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index eb2b836169d663..cb604909e1927c 100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -401,18 +401,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index f3ebc9324260b9..187f39248fbc8e 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -817,22 +817,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff 
--git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 5b93a16d3e0252..d5cd57dea7ccf5 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -1216,18 +1216,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) if conv_layer is not None: output = self.gradient_checkpointing_func(conv_layer, layer_outputs[0]) @@ -1448,7 +1442,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 7b02ee85020cfb..74c6d875f222de 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -463,17 +463,11 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -948,7 +942,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 82531ab7a455ee..dc094bd8ba0bff 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -487,20 +487,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index 30ff103bea7daa..fcb1dd37de7228 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -447,10 +447,11 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, 
hidden_states, attention_mask, layer_head_mask, + output_attentions, rel_pos=rel_pos, rel_2d_pos=rel_2d_pos, ) diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 42162dcfb2e54f..9afc855417fabd 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -657,19 +657,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - # return module(*inputs, past_key_value, output_attentions, rel_pos, rel_2d_pos) - # The above line will cause error: - # RuntimeError: Trying to backward through the graph a second time - # (or directly access saved tensors after they have already been freed). - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 1029a7950a2e20..5850923ffdca5c 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1877,20 +1877,15 @@ def forward( layer_outputs = (None, None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, is_global_attn, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, is_index_masked, is_index_global_attn, + is_global_attn, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2152,7 +2147,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 65c381fc50a935..2c7085aa822821 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -514,19 +514,13 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layout_inputs, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 5664a581ffb771..340f02abea076b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -914,16 +914,15 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, 
past_key_value, output_attentions, padding_mask=padding_mask) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids + decoder_layer.forward, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + padding_mask, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 3b77ad46aed3da..6ca8f61cfa4ca1 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1304,20 +1304,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, is_global_attn, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, is_index_masked, is_index_global_attn, + is_global_attn, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 4c6ff76cc95d78..b4d3c3ba495f58 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1511,15 +1511,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1529,6 +1522,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index fde39d0999af64..143932f924bf6a 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -788,19 +788,13 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, word_hidden_states, entity_hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 264aff5b4aac4b..080949adbeb679 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -821,18 +821,12 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def 
create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1076,7 +1070,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index a0ab7192718bdc..7bf8aac0aef680 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -789,18 +789,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1047,7 +1041,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 9686b0a1d3051d..fc15c86e7a9460 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -648,20 +648,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 9ec586a17bb34b..7d00b6b6d87127 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -1864,20 +1864,14 @@ def forward( continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, None, None, + output_attentions, ) else: diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 8502a6a368eacd..a941c0508a94f0 100644 --- 
a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -848,20 +848,14 @@ def forward( continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, combined_attention_mask, encoder_hidden_states, encoder_attention_mask, None, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index fe9dbc91f801a0..dd6c45de8a56b0 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -688,15 +688,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_hidden_states, output_dimensions, layer_all_hidden_states = self.gradient_checkpointing_func( - create_custom_forward(layer_module), hidden_states, layer_head_mask + layer_module.forward, + hidden_states, + layer_head_mask, + output_attentions, ) else: layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module( diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 644c5d292b0eea..aa5f17215e90bd 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -829,18 +829,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1096,7 +1090,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 16d463dcb470b1..a2e2a39ec966ab 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -551,20 +551,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git 
a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 1544ebeaaf8180..e9c17cc25ccf42 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -919,19 +919,15 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, position_ids, + past_key_value, + output_attentions, + use_cache, + padding_mask, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index 0653321df9c38b..1e8a8afa07ddd0 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -626,15 +626,8 @@ def forward( for i, layer_module in enumerate(self.layer): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, ) else: diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index 5aca04266e464f..c857915a8cca99 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -582,15 +582,8 @@ def forward( for i, layer_module in enumerate(self.layer): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, ) else: diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 279b0bc903a5b2..897a90ce0486a1 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -524,20 +524,14 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, alibi, causal_mask, layer_past, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index 672e2666533ded..1da9da2af9159f 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -766,15 +766,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - 
return module(*inputs) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, ) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 2e2b68060dc923..2951ffc889dcdb 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1074,15 +1074,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1092,6 +1085,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 6bee6f35dc7d60..a740ed47074bc2 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -827,16 +827,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, @@ -844,6 +836,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index f44e067aac3102..71a1a166d8483b 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -950,19 +950,13 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), (self_attn_prompt[idx] if self.use_prompt else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1237,7 +1231,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/nezha/modeling_nezha.py b/src/transformers/models/nezha/modeling_nezha.py index 5a94e43291cb7e..a8ad52d2698831 100644 --- a/src/transformers/models/nezha/modeling_nezha.py +++ b/src/transformers/models/nezha/modeling_nezha.py @@ -577,20 +577,15 @@ def forward( past_key_value = past_key_values[i] 
if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 72cf7e3a300586..883589de241038 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -1154,18 +1154,12 @@ def forward( layer_outputs = (None, None, None) else: if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1435,7 +1429,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = checkpoint( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 3c5df5dedd2e46..9a023cbc91ef36 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -370,17 +370,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index c97d57fa236f4e..5782d796566a04 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -692,20 +692,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, causal_attention_mask, head_mask[idx] if head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index 5aee16cc81060a..351a1a77d59a7e 100644 --- 
a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -765,18 +765,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index b5317ea1c1b86b..63e1570a110697 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -754,18 +754,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 705cf956f78482..dbe93bc18becc3 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -804,18 +804,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1097,7 +1091,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 5f588842923169..a29a1250a976e0 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1072,18 +1072,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, global_hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1341,7 +1335,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git 
a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index c6092e158c93a5..28a12d5eb33822 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -670,19 +670,13 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, position_ids, + past_key_value, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 31cedc13359f95..acbe0996d5ae78 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -343,18 +343,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -1505,7 +1499,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index a079b0bf0cf5be..59803ed363e12f 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -808,18 +808,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1074,7 +1068,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index acb43f824b7bf7..5cf7039e9f0c2e 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -903,15 +903,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, 
output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -921,6 +914,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 04d2b946eafcf5..f0016c8c206db7 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1330,18 +1330,12 @@ def forward( encoder_hidden_states = encoder_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, extended_attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1579,7 +1573,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, extended_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index cf307fb35009f3..69be03b93bded0 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -581,20 +581,15 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index 8f7d0a65600296..a63e3a9e9bce6f 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -586,20 +586,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 6dd04ed4030c47..6471653da7bf74 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -543,20 +543,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index d7ead17b45449c..aedfc5ef807780 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -510,20 +510,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 4ae7a308f68e5d..1bcdb872451889 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -512,20 +512,15 @@ def forward( past_key_value = past_key_values[i] 
if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index d1b84ab31f631a..3627944fab4b95 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -644,20 +644,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index e860ff34eb5295..6773a6f967adb2 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -578,21 +578,16 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, sinusoidal_pos, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index bbe21949f0e67a..d7c7df9a839002 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -677,16 +677,8 @@ def forward( all_hidden_states = () if output_hidden_states else None for idx, block in enumerate(self.blocks): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - hidden_states, state, attentions = self.gradient_checkpointing_func( - create_custom_forward(block), hidden_states, state + block.forward, hidden_states, state, use_cache, output_attentions ) else: hidden_states, state, attentions = block( diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index f5cd7cf0a45bd5..d384747af33653 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -1042,15 +1042,8 @@ def forward( all_hidden_states = all_hidden_states + 
(hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, ) else: diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 6b0e7a7ff079b2..b0b56b80d268e4 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -1857,18 +1857,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 44cbcec5267af3..6a3cde064ff8ef 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -360,15 +360,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -681,7 +674,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 74374e1a4eb98e..f18622538e41f4 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -453,15 +453,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -1134,20 +1127,14 @@ def forward( all_hidden_states = all_hidden_states + (output_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - output_states = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, next_kv, attention_mask, query_states, relative_pos, rel_embeddings, + output_attentions, ) else: output_states = layer_module( diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index acdcc2f902cb90..9d75dc4f3da0c6 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ 
b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -818,18 +818,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1075,7 +1069,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index 9d863ba3e2f2fe..486dda2f46b4c7 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -670,16 +670,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index ef374bbb32e789..b470cab687d209 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -520,15 +520,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -1395,7 +1388,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1725,7 +1718,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index f1ab50179dea6a..d766f435f15010 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -459,20 +459,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = 
self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 228d962dea1dc5..25432478abeaf1 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -825,15 +825,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index db8ff6a652eb63..d7b248b1135990 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -746,15 +746,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(stage_module), hidden_states, input_dimensions, layer_head_mask + stage_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = stage_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index fda0e080d0d8d0..c00ae39e0bec2d 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -906,15 +906,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index ed0d59abb8bc5d..32d030728de579 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1040,15 +1040,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - 
create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1058,6 +1051,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 603c6a4730e8f3..c796a9cf24cfb0 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1101,15 +1101,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1119,6 +1112,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index fb42673ae5c55f..e72975a200a075 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -1144,15 +1144,8 @@ def forward( continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index e6ce415899faec..ae22bbd8449d2c 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -646,20 +646,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_values, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_values, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index c550f89e9504c0..9b44713dc64aa3 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -947,18 +947,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - 
return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1173,7 +1167,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index df7dd1c953f5bf..ccc65287cdc20a 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -439,16 +439,10 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, output_attentions) diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 6971b4dfb21a8f..9b7fab8e2f3d4d 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -702,16 +702,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, @@ -719,6 +711,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/tvlt/modeling_tvlt.py b/src/transformers/models/tvlt/modeling_tvlt.py index 8852083c4694cd..fcf61142ced65f 100644 --- a/src/transformers/models/tvlt/modeling_tvlt.py +++ b/src/transformers/models/tvlt/modeling_tvlt.py @@ -568,10 +568,11 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -886,7 +887,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, None, ) diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index e6e9aaa26a380b..a5b58444fe4edf 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -710,15 +710,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if 
self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, encoder_hidden_states, @@ -726,6 +719,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 8f667d3d564c48..4708ff4173dd3c 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -384,15 +384,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -775,7 +768,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) @@ -865,7 +858,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 5584929ab11c00..2d57b063819854 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -398,15 +398,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -789,7 +782,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) @@ -879,7 +872,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 9657a747fbc365..27e09730cfde7f 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -434,17 +434,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + 
layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -735,7 +729,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, None, ) diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 9c8fee6c79f361..1d9db412d37f84 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -539,10 +539,11 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index c2eaa90b48640e..36a1292fc9fdb0 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -418,18 +418,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 050d02ee2990c1..b06ab62113a745 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -397,17 +397,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py index fa4b3471e37579..7b54e6c1535b3f 100644 --- a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py @@ -415,17 +415,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git 
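In the T5-family hunks above (t5, mt5, umt5, switch_transformers, pop2piano), the removed closure used to append `use_cache` and `output_attentions` itself, so the refactored call now lists them explicitly after the `None` placeholder for `past_key_value`, which, as the retained comment notes, is always None under gradient checkpointing. Because `checkpoint()` hands the wrapped function exactly the positional arguments it receives, and its classic reentrant variant rejects extra keyword arguments, every optional argument has to be spelled out positionally in signature order. A minimal sketch of that ordering, with an invented `ToyT5Block` whose signature only loosely mirrors the real blocks:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyT5Block(nn.Module):
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
    ):
        return (self.linear(hidden_states),)


block = ToyT5Block()
hidden_states = torch.randn(2, 4, 8, requires_grad=True)

layer_outputs = checkpoint(
    block.forward,
    hidden_states,
    None,   # attention_mask
    None,   # position_bias
    None,   # layer_head_mask
    None,   # past_key_value is always None with gradient checkpointing
    False,  # use_cache
    False,  # output_attentions
    use_reentrant=False,  # assumption for the toy example; not part of the patch
)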
a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index b468075d08ff3e..0e27a335ddb678 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -536,17 +536,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -802,7 +796,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, None, ) diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 87779dd3ae944d..91e13c7b6adc9d 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -387,17 +387,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py index fd9f26923444cc..8e20f17e070920 100644 --- a/src/transformers/models/vitdet/modeling_vitdet.py +++ b/src/transformers/models/vitdet/modeling_vitdet.py @@ -565,17 +565,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index 7b7899ee287f31..f3cda24b85cc15 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -1167,18 +1167,12 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, padding_mask, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git 
a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index 5e07b1544b2b2c..b4ed99bd9e98a1 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -338,17 +338,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index c02e23660b5666..2e18a26633cfe0 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -451,15 +451,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -811,7 +804,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) @@ -900,7 +893,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index edcdcf4a22ac47..6f2c28624df799 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -518,15 +518,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -919,7 +912,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, relative_position_embeddings, diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 182482dfd83ad6..defe32a5103c8f 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -354,15 +354,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - 
create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -721,7 +714,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, position_bias, @@ -812,7 +805,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, position_bias, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index f4b8fb4852a98f..0b7341fa6d3de6 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -943,18 +943,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, None, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1184,7 +1178,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 025533ab41781a..de3a4376e4a608 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -704,18 +704,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -959,7 +953,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 16f0402abf98ce..f6f518e8f5ce8b 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -676,15 +676,8 @@ def forward( if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, @@ -692,6 +685,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git 
a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py index e599cc3cede777..e07d343b62cbfe 100644 --- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -1350,18 +1350,12 @@ def forward( encoder_hidden_states = encoder_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, extended_attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1602,7 +1596,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, extended_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index b195ee43723e51..1bc22ca1004580 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -511,20 +511,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 0e3ed4eeb986f1..3477d709ae0e33 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -499,20 +499,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 2eb0ba83d7260c..c6fcc0bb7c2130 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -574,20 +574,16 @@ def forward( if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - 
create_custom_forward(layer_module), + layer_module.forward, hidden_states, lang_ids, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index 0884529777c205..4e1825a457bc05 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -493,16 +493,11 @@ def forward( if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index 0159d7fb76d32f..b0cbd589b293b8 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -561,17 +561,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index ee583cec354899..2071c90a83bb5e 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -544,19 +544,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -2322,7 +2318,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -2561,7 +2557,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, From 2b5a6695333dcd5a36fad3f5c547e88bdf03de15 Mon Sep 17 00:00:00 2001 From: 
younesbelkada Date: Mon, 23 Oct 2023 14:38:45 +0000 Subject: [PATCH 04/12] fixup --- src/transformers/models/seamless_m4t/modeling_seamless_m4t.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index ebd1361ed2e9d8..f7f74201d2d3d0 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -1862,7 +1862,7 @@ def forward( hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), - output_attentions + output_attentions, ) else: layer_outputs = encoder_layer( From 6fbe101677c9d06450a18f04be33dac8a41cf205 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 23 Oct 2023 14:46:07 +0000 Subject: [PATCH 05/12] fixup --- .../models/deformable_detr/modeling_deformable_detr.py | 1 - src/transformers/models/deta/modeling_deta.py | 1 - src/transformers/models/xglm/modeling_xglm.py | 1 - src/transformers/models/xmod/modeling_xmod.py | 1 - src/transformers/models/yolos/modeling_yolos.py | 1 - 5 files changed, 5 deletions(-) diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index fb8ed41ce7125b..507a1815149089 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1382,7 +1382,6 @@ def forward( all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index ff24ed74856a24..0853a3f82208d6 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -1274,7 +1274,6 @@ def forward( all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index f6f518e8f5ce8b..075e0c3159704c 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -675,7 +675,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index c6fcc0bb7c2130..4ca4adeec995bd 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -573,7 +573,6 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index 4e1825a457bc05..a378e96f9909c2 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -492,7 +492,6 @@ def forward( layer_head_mask = head_mask[i] if 
head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, From 634b5e7fe959f280176c0d5710bfe783f808974f Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 24 Oct 2023 09:17:42 +0000 Subject: [PATCH 06/12] add test and fix all failing GC tests --- src/transformers/models/align/modeling_align.py | 2 +- src/transformers/models/blip/modeling_blip.py | 4 ++-- src/transformers/models/blip_2/modeling_blip_2.py | 13 +++++-------- src/transformers/models/gpt2/modeling_gpt2.py | 1 - .../models/groupvit/modeling_groupvit.py | 1 - src/transformers/models/hubert/modeling_hubert.py | 2 +- .../models/instructblip/modeling_instructblip.py | 13 +++++-------- src/transformers/models/longt5/modeling_longt5.py | 4 +--- src/transformers/models/mbart/modeling_mbart.py | 2 +- .../models/pix2struct/modeling_pix2struct.py | 2 +- .../models/seamless_m4t/modeling_seamless_m4t.py | 2 +- src/transformers/models/sew_d/modeling_sew_d.py | 2 +- .../models/speecht5/modeling_speecht5.py | 6 ------ src/transformers/models/tvlt/modeling_tvlt.py | 2 +- .../models/videomae/modeling_videomae.py | 2 +- .../models/vit_mae/modeling_vit_mae.py | 2 +- .../models/vitmatte/modeling_vitmatte.py | 5 +++++ src/transformers/models/vits/modeling_vits.py | 2 +- tests/test_modeling_common.py | 14 ++++++++++++++ 19 files changed, 42 insertions(+), 39 deletions(-) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index e132fae5c660e6..7b141b5f65a367 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -1193,7 +1193,7 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (AlignTextModel, AlignVisionModel)): + if isinstance(module, (AlignTextModel, AlignVisionModel, AlignTextEncoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 229afec0a81f8f..927c33f9927c08 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -34,7 +34,7 @@ replace_return_docstrings, ) from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig -from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel +from .modeling_blip_text import BlipTextEncoder, BlipTextLMHeadModel, BlipTextModel logger = logging.get_logger(__name__) @@ -462,7 +462,7 @@ def _init_weights(self, module): module.bias.data.zero_() def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, BlipEncoder): + if isinstance(module, (BlipEncoder, BlipTextEncoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 8339016efcb91c..735b81bc4229e1 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -298,10 +298,14 @@ def _init_weights(self, module): module.bias.data.zero_() def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if 
isinstance(module, Blip2Encoder): + if isinstance(module, (Blip2Encoder, Blip2QFormerEncoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None + # Enable / disable GC for the language model as well + if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"): + self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func) + BLIP_2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -939,13 +943,6 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions, query_length) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index ee84cb1bc88fcf..dc28ed3640f472 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1618,7 +1618,6 @@ def __init__(self, config): # Model parallel self.model_parallel = False self.device_map = None - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index e4aeaf70996fec..332b14d9961cb7 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -492,7 +492,6 @@ def __init__( self.group_token = nn.Parameter(torch.zeros(1, num_group_token, config.hidden_size)) else: self.group_token = None - self.gradient_checkpointing = False self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(depth)]) if num_group_token > 0: diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 94b2a205d8ca1c..b215063090d0a6 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -889,7 +889,7 @@ def _init_weights(self, module): module.bias.data.zero_() def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (HubertEncoder, HubertEncoderStableLayerNorm)): + if isinstance(module, (HubertFeatureEncoder, HubertEncoder, HubertEncoderStableLayerNorm)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 74c6d875f222de..3cc44efbe3618a 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -305,10 +305,14 @@ def _init_weights(self, module): module.bias.data.zero_() def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, InstructBlipEncoder): + if isinstance(module, (InstructBlipEncoder, InstructBlipQFormerEncoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None + 
# Enable / disable GC for the language model as well + if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"): + self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func) + INSTRUCTBLIP_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -934,13 +938,6 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions, query_length) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index b4d3c3ba495f58..9abbfa2f2001f6 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -775,7 +775,6 @@ def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = Fal if self.has_relative_attention_bias: self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) self.pruned_heads = set() - self.gradient_checkpointing = False # Relativen attention bias & Layer norm for global attention if self.has_relative_attention_bias: @@ -1340,9 +1339,8 @@ def _init_weights(self, module): mean=0.0, std=factor * ((d_model) ** -0.5) ) - # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._set_gradient_checkpointing with T5->LongT5 def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (LongT5Attention, LongT5Stack)): + if isinstance(module, (LongT5Attention, LongT5Stack, LongT5LocalAttention)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index aa5f17215e90bd..29bcd445a8f223 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -517,7 +517,7 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (MBartDecoder, MBartDecoder)): + if isinstance(module, (MBartDecoder, MBartEncoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index acbe0996d5ae78..0efaab7cec5954 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -558,7 +558,7 @@ def __init__(self, config: Pix2StructConfig): self.post_init() def _set_gradient_checkpointing(self, module: Pix2StructVisionEncoder, gradient_checkpointing_func=None) -> None: - if isinstance(module, Pix2StructVisionEncoder): + if isinstance(module, (Pix2StructVisionEncoder, Pix2StructVisionAttention)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git 
a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index f7f74201d2d3d0..b0bc0a1350919b 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -1548,7 +1548,7 @@ def _init_weights(self, module): nn.init.uniform_(module.bias, a=-k, b=k) def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (SeamlessM4TDecoder, SeamlessM4TEncoder, SeamlessM4TSpeechEncoder)): + if isinstance(module, (SeamlessM4TDecoder, SeamlessM4TEncoder, SeamlessM4TConformerEncoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index f18622538e41f4..2dc2231e607335 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -1310,7 +1310,7 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti return attention_mask def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, SEWDTransformerEncoder): + if isinstance(module, (SEWDEncoder, SEWDFeatureEncoder, SEWDTransformerEncoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index b470cab687d209..c1fef6df94d1f5 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -1433,7 +1433,6 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.prenet = SpeechT5SpeechEncoderPrenet(config) self.wrapped_encoder = SpeechT5Encoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1470,7 +1469,6 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.prenet = SpeechT5TextEncoderPrenet(config) self.wrapped_encoder = SpeechT5Encoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1513,7 +1511,6 @@ class SpeechT5EncoderWithoutPrenet(SpeechT5PreTrainedModel): def __init__(self, config: SpeechT5Config): super().__init__(config) self.wrapped_encoder = SpeechT5Encoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1782,7 +1779,6 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.prenet = SpeechT5SpeechDecoderPrenet(config) self.wrapped_decoder = SpeechT5Decoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1830,7 +1826,6 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.prenet = SpeechT5TextDecoderPrenet(config) self.wrapped_decoder = SpeechT5Decoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1883,7 +1878,6 @@ class SpeechT5DecoderWithoutPrenet(SpeechT5PreTrainedModel): def __init__(self, config: SpeechT5Config): super().__init__(config) self.wrapped_decoder = SpeechT5Decoder(config) - self.gradient_checkpointing = False # Initialize 
weights and apply final processing self.post_init() diff --git a/src/transformers/models/tvlt/modeling_tvlt.py b/src/transformers/models/tvlt/modeling_tvlt.py index fcf61142ced65f..65da6a46339a43 100644 --- a/src/transformers/models/tvlt/modeling_tvlt.py +++ b/src/transformers/models/tvlt/modeling_tvlt.py @@ -618,7 +618,7 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, TvltEncoder): + if isinstance(module, (TvltEncoder, TvltDecoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 27e09730cfde7f..203aacc3f36584 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -484,7 +484,7 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, VideoMAEEncoder): + if isinstance(module, (VideoMAEEncoder, VideoMAEDecoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 0e27a335ddb678..14e047c5acc85a 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -586,7 +586,7 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, ViTMAEEncoder): + if isinstance(module, (ViTMAEEncoder, ViTMAEDecoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py index 18b8b80b328c31..f5025a37e71c15 100644 --- a/src/transformers/models/vitmatte/modeling_vitmatte.py +++ b/src/transformers/models/vitmatte/modeling_vitmatte.py @@ -91,6 +91,11 @@ def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None + for backbone_module in module.modules(): + if hasattr(backbone_module, "gradient_checkpointing"): + backbone_module.gradient_checkpointing_func = gradient_checkpointing_func + backbone_module.gradient_checkpointing = gradient_checkpointing_func is not None + class VitMatteBasicConv3x3(nn.Module): """ diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index f3cda24b85cc15..49b8e1a6a40a18 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -1291,7 +1291,7 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (VitsTextEncoder)): + if isinstance(module, VitsEncoder): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None 
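Illustrative sketch (not part of this patch series): the pattern the hunks above converge on. The old code wrapped each layer in a `create_custom_forward` closure so that non-tensor flags such as `output_attentions` were captured by the closure; the new code passes the bound `layer.forward` / `layer_module.forward` to `gradient_checkpointing_func` and forwards those flags as explicit positional arguments. The toy module and shapes below are hypothetical, and `gradient_checkpointing_func` is stood in for by `torch.utils.checkpoint.checkpoint`, which is the default callable wired in by `transformers`.

import torch
from torch import nn
# Stand-in for the checkpointing callable set by gradient_checkpointing_enable().
from torch.utils.checkpoint import checkpoint as gradient_checkpointing_func


class ToyLayer(nn.Module):
    """Hypothetical layer mimicking a transformer layer's call signature."""

    def __init__(self, hidden_size=8):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        hidden_states = self.dense(hidden_states)
        # Mimic the usual return convention: (hidden_states, [attn_weights])
        return (hidden_states, None) if output_attentions else (hidden_states,)


layer = ToyLayer()
hidden_states = torch.randn(2, 4, 8, requires_grad=True)
output_attentions = True

# Old pattern (removed in this series): capture the flag in a closure so it never
# passes through checkpoint() itself.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs, output_attentions)

    return custom_forward

# Recent torch versions may warn unless use_reentrant is passed explicitly.
old_out = gradient_checkpointing_func(create_custom_forward(layer), hidden_states, None)

# New pattern: call the bound forward directly and pass every argument explicitly.
new_out = gradient_checkpointing_func(layer.forward, hidden_states, None, output_attentions)

assert torch.allclose(old_out[0], new_out[0])
old_out[0].sum().backward()  # gradients still flow through the checkpointed segment

Both calls compute the same activations; the only difference is how the non-tensor arguments reach the checkpointed function, which is what allows every remaining `create_custom_forward` helper to be deleted in the hunks below.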
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 74ea66fd6fa01d..7e1c471badf417 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -349,10 +349,24 @@ def test_gradient_checkpointing_enable_disable(self): model.gradient_checkpointing_enable() self.assertTrue(model.is_gradient_checkpointing) + # Loop over all modules and check that relevant modules have gradient_checkpointing set to True + for n, m in model.named_modules(): + if hasattr(m, "gradient_checkpointing"): + self.assertTrue( + m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to True" + ) + # check disable works model.gradient_checkpointing_disable() self.assertFalse(model.is_gradient_checkpointing) + # Loop over all modules and check that relevant modules have gradient_checkpointing set to False + for n, m in model.named_modules(): + if hasattr(m, "gradient_checkpointing"): + self.assertFalse( + m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False" + ) + def test_save_load_fast_init_from_base(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() if config.__class__ not in MODEL_MAPPING: From 476d261db8fca68b04d652b70d785686666b0e77 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 24 Oct 2023 09:28:07 +0000 Subject: [PATCH 07/12] remove all remaining `create_custom_forward` methods --- .../models/autoformer/modeling_autoformer.py | 10 ++------- src/transformers/models/bart/modeling_bart.py | 10 ++------- .../modeling_bigbird_pegasus.py | 10 ++------- .../models/blenderbot/modeling_blenderbot.py | 10 ++------- .../modeling_blenderbot_small.py | 10 ++------- .../chinese_clip/modeling_chinese_clip.py | 8 +------ .../models/convbert/modeling_convbert.py | 7 ------- .../data2vec/modeling_data2vec_audio.py | 8 +------ .../models/hubert/modeling_hubert.py | 16 ++------------ .../models/informer/modeling_informer.py | 10 ++------- .../models/layoutlmv2/modeling_layoutlmv2.py | 7 ------- src/transformers/models/led/modeling_led.py | 10 ++------- .../models/m2m_100/modeling_m2m_100.py | 10 ++------- .../models/marian/modeling_marian.py | 10 ++------- .../models/mbart/modeling_mbart.py | 10 ++------- src/transformers/models/mvp/modeling_mvp.py | 10 ++------- .../models/nllb_moe/modeling_nllb_moe.py | 9 ++------ .../models/pegasus/modeling_pegasus.py | 10 ++------- .../models/pegasus_x/modeling_pegasus_x.py | 10 ++------- .../models/pix2struct/modeling_pix2struct.py | 9 ++------ .../models/plbart/modeling_plbart.py | 10 ++------- .../models/prophetnet/modeling_prophetnet.py | 10 ++------- .../seamless_m4t/modeling_seamless_m4t.py | 21 ++++--------------- src/transformers/models/sew/modeling_sew.py | 8 +------ .../speech_to_text/modeling_speech_to_text.py | 10 ++------- .../models/speecht5/modeling_speecht5.py | 18 +++------------- .../modeling_time_series_transformer.py | 10 ++------- src/transformers/models/tvlt/modeling_tvlt.py | 15 +------------ .../models/unispeech/modeling_unispeech.py | 16 ++------------ .../unispeech_sat/modeling_unispeech_sat.py | 16 ++------------ .../models/videomae/modeling_videomae.py | 8 +------ src/transformers/models/vilt/modeling_vilt.py | 7 ------- .../models/vit_mae/modeling_vit_mae.py | 8 +------ .../models/wav2vec2/modeling_wav2vec2.py | 16 ++------------ .../modeling_wav2vec2_conformer.py | 8 +------ .../models/wavlm/modeling_wavlm.py | 16 ++------------ .../models/whisper/modeling_whisper.py | 10 ++------- 
.../models/x_clip/modeling_x_clip.py | 8 +------ .../xlm_prophetnet/modeling_xlm_prophetnet.py | 10 ++------- ...ng_{{cookiecutter.lowercase_modelname}}.py | 17 +++------------ 40 files changed, 70 insertions(+), 366 deletions(-) diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 520a9ddbc13124..29073c3d57dd3e 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -1420,14 +1420,6 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1437,6 +1429,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 70013fba27dfff..390af1a825a753 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1105,14 +1105,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1122,6 +1114,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index ad0640b04451f2..03ef911970ad87 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2284,14 +2284,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -2301,6 +2293,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 4a3248e5d44356..35879ac1500a97 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -1027,14 +1027,6 @@ def forward( past_key_value = 
past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1044,6 +1036,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index ef9d0e9643f5d6..59ba6b9dd874b8 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -1024,14 +1024,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1041,6 +1033,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index c96521493fd5aa..a010d82fd9de5e 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -1014,16 +1014,10 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( encoder_layer.forward, hidden_states, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index c040715c363092..e240830214253b 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -633,13 +633,6 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 71fd5a705990ba..5a2491571efaca 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -586,17 +586,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run 
in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index b215063090d0a6..e5b1b1742e74ba 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -724,17 +724,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -814,17 +808,11 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index d5cd57dea7ccf5..423de7d819769a 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -1433,14 +1433,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1450,6 +1442,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index fcb1dd37de7228..03900bff907c91 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -439,13 +439,6 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 5850923ffdca5c..3d4e3c26188c57 
100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -2138,14 +2138,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -2155,6 +2147,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 080949adbeb679..b9b672ca28291c 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -1061,14 +1061,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1078,6 +1070,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 7bf8aac0aef680..81a4d7b6f6b527 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -1032,14 +1032,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1049,6 +1041,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 29bcd445a8f223..341260efe45cbb 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1081,14 +1081,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1098,6 +1090,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if 
cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 71a1a166d8483b..d8622fca958264 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -1222,14 +1222,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1241,6 +1233,8 @@ def custom_forward(*inputs): self_attn_prompt[idx] if self.use_prompt else None, cross_attn_prompt[idx] if self.use_prompt else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 883589de241038..51bbd56d2b58cc 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -1421,13 +1421,6 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( decoder_layer.forward, hidden_states, @@ -1437,6 +1430,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index dbe93bc18becc3..5fc671f25f4684 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -1082,14 +1082,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1099,6 +1091,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index a29a1250a976e0..f35bef20969d30 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1326,14 +1326,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - 
layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1341,6 +1333,8 @@ def custom_forward(*inputs): encoder_hidden_states, encoder_attention_mask, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 0efaab7cec5954..9b4444c56ee294 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1491,13 +1491,6 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( layer_module.forward, hidden_states, @@ -1509,6 +1502,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 59803ed363e12f..cdd73be66d7ae3 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -1059,14 +1059,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1076,6 +1068,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index f0016c8c206db7..eb1b319fb19a4a 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1564,14 +1564,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1585,6 +1577,8 @@ def custom_forward(*inputs): predict_relative_position_buckets, position_ids, None, + use_cache, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index b0bc0a1350919b..a930d60ec9da4c 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -892,14 +892,7 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and 
self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, @@ -2125,15 +2118,7 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, attention_mask, @@ -2142,6 +2127,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 6a3cde064ff8ef..883fab34fce208 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -666,17 +666,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 9d75dc4f3da0c6..030358ff033a1d 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -1060,14 +1060,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1077,6 +1069,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index c1fef6df94d1f5..40d30f366a2066 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -1380,19 +1380,13 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward 
- layer_outputs = self.gradient_checkpointing_func( encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), position_bias, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1706,14 +1700,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1723,6 +1709,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 9b44713dc64aa3..349bc5d48adfe2 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -1158,14 +1158,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1175,6 +1167,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/tvlt/modeling_tvlt.py b/src/transformers/models/tvlt/modeling_tvlt.py index 65da6a46339a43..086cf66fd40dda 100644 --- a/src/transformers/models/tvlt/modeling_tvlt.py +++ b/src/transformers/models/tvlt/modeling_tvlt.py @@ -560,13 +560,6 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, @@ -879,17 +872,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, None, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 4708ff4173dd3c..bcfc4069c8a35f 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -760,17 +760,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus 
must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -850,17 +844,11 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 2d57b063819854..778dbfad18a9e0 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -774,17 +774,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -864,17 +858,11 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 203aacc3f36584..84ff258c58b812 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -721,17 +721,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, None, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 1d9db412d37f84..a93dc99903e1bb 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -531,13 +531,6 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if 
self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 14e047c5acc85a..5fa10ca9d1376c 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -788,17 +788,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, None, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 2e18a26633cfe0..ec38d6a11570ff 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -796,17 +796,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -885,17 +879,11 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 6f2c28624df799..5d723592556843 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -904,18 +904,12 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, relative_position_embeddings, + output_attentions, ) else: layer_outputs = layer( diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index defe32a5103c8f..ef76b43330890b 100755 --- 
a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -706,18 +706,12 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, position_bias, + output_attentions, ) else: layer_outputs = layer( @@ -797,18 +791,12 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, position_bias, + output_attentions, ) else: layer_outputs = layer( diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 0b7341fa6d3de6..c868abe44c0edc 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1169,14 +1169,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1186,6 +1178,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, # past_key_value + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index de3a4376e4a608..46ad1fb719e7cb 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -945,18 +945,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py index e07d343b62cbfe..bc86fd6ff5c8fb 100644 --- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -1587,14 +1587,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None 
for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1608,6 +1600,8 @@ def custom_forward(*inputs): predict_relative_position_buckets, position_ids, None, + use_cache, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 2071c90a83bb5e..e2e36ac3682308 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -2310,18 +2310,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2549,13 +2543,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -2565,6 +2552,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: From 465849c02d1b72a6974f5a3d47facbfeb04f45b4 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 24 Oct 2023 09:48:05 +0000 Subject: [PATCH 08/12] fix idefics bug --- src/transformers/models/idefics/modeling_idefics.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index d3f9c5da4d2d7a..5c2d6f996319fc 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -40,7 +40,7 @@ ) from .configuration_idefics import IdeficsConfig from .perceiver import IdeficsPerceiverResampler -from .vision import IdeficsVisionTransformer +from .vision import IdeficsVisionTransformer, IdeficsVisionEncoder logger = logging.get_logger(__name__) @@ -979,7 +979,7 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, IdeficsModel): + if isinstance(module, (IdeficsModel, IdeficsVisionEncoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None @@ -1099,7 +1099,6 @@ def __init__(self, config: IdeficsConfig): self.norm = IdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.gradient_checkpointing = False # Initialize weights and apply final 
processing self.post_init() From 7e5eeda035a5041904a59a23afeca827c3c541e5 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 24 Oct 2023 09:56:11 +0000 Subject: [PATCH 09/12] fixup --- src/transformers/models/idefics/modeling_idefics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 5c2d6f996319fc..28841903a1a3bb 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -40,7 +40,7 @@ ) from .configuration_idefics import IdeficsConfig from .perceiver import IdeficsPerceiverResampler -from .vision import IdeficsVisionTransformer, IdeficsVisionEncoder +from .vision import IdeficsVisionEncoder, IdeficsVisionTransformer logger = logging.get_logger(__name__) From 967ed0db036766dc1e9a964eca66a8c07df1b902 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 25 Oct 2023 09:04:43 +0000 Subject: [PATCH 10/12] replace with `__call__` --- src/transformers/models/align/modeling_align.py | 2 +- src/transformers/models/altclip/modeling_altclip.py | 4 ++-- .../modeling_audio_spectrogram_transformer.py | 2 +- src/transformers/models/autoformer/modeling_autoformer.py | 4 ++-- src/transformers/models/bark/modeling_bark.py | 2 +- src/transformers/models/bart/modeling_bart.py | 4 ++-- src/transformers/models/beit/modeling_beit.py | 2 +- src/transformers/models/bert/modeling_bert.py | 2 +- .../models/bert_generation/modeling_bert_generation.py | 2 +- src/transformers/models/big_bird/modeling_big_bird.py | 2 +- .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 4 ++-- src/transformers/models/biogpt/modeling_biogpt.py | 2 +- src/transformers/models/blenderbot/modeling_blenderbot.py | 4 ++-- .../models/blenderbot_small/modeling_blenderbot_small.py | 4 ++-- src/transformers/models/blip/modeling_blip.py | 2 +- src/transformers/models/blip/modeling_blip_text.py | 2 +- src/transformers/models/blip_2/modeling_blip_2.py | 4 ++-- src/transformers/models/bloom/modeling_bloom.py | 2 +- src/transformers/models/bridgetower/modeling_bridgetower.py | 2 +- src/transformers/models/bros/modeling_bros.py | 2 +- src/transformers/models/camembert/modeling_camembert.py | 2 +- src/transformers/models/canine/modeling_canine.py | 2 +- .../models/chinese_clip/modeling_chinese_clip.py | 4 ++-- src/transformers/models/clap/modeling_clap.py | 4 ++-- src/transformers/models/clip/modeling_clip.py | 2 +- src/transformers/models/clipseg/modeling_clipseg.py | 2 +- src/transformers/models/codegen/modeling_codegen.py | 2 +- .../models/conditional_detr/modeling_conditional_detr.py | 2 +- src/transformers/models/convbert/modeling_convbert.py | 2 +- src/transformers/models/data2vec/modeling_data2vec_audio.py | 4 ++-- src/transformers/models/data2vec/modeling_data2vec_text.py | 2 +- .../models/data2vec/modeling_data2vec_vision.py | 2 +- src/transformers/models/deberta/modeling_deberta.py | 2 +- src/transformers/models/deberta_v2/modeling_deberta_v2.py | 2 +- .../decision_transformer/modeling_decision_transformer.py | 2 +- .../models/deformable_detr/modeling_deformable_detr.py | 2 +- src/transformers/models/deit/modeling_deit.py | 2 +- src/transformers/models/deprecated/mctct/modeling_mctct.py | 2 +- .../models/deprecated/open_llama/modeling_open_llama.py | 2 +- .../modeling_trajectory_transformer.py | 2 +- src/transformers/models/deta/modeling_deta.py | 2 +- src/transformers/models/detr/modeling_detr.py | 2 +- 
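# Illustrative sketch, not from the patch: the Idefics hunk above widens the isinstance
# check in `_set_gradient_checkpointing` so the vision encoder is wired up as well. The
# toy classes below (`ToyVisionEncoder`, `ToyTextModel`, `ToyPreTrainedModel`) are
# hypothetical and only mirror the shape of that hook; the enabling loop stands in for
# however the library applies the hook to its submodules.
import torch
from torch import nn


class ToyVisionEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False
        self.gradient_checkpointing_func = None


class ToyTextModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False
        self.gradient_checkpointing_func = None
        self.vision_encoder = ToyVisionEncoder()


class ToyPreTrainedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToyTextModel()

    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
        # Every class that checkpoints its layers must be listed here; leaving one out
        # (the bug fixed for Idefics) leaves that submodule with a None function even
        # though its forward still expects to call it when the flag is set.
        if isinstance(module, (ToyTextModel, ToyVisionEncoder)):
            module.gradient_checkpointing_func = gradient_checkpointing_func
            module.gradient_checkpointing = gradient_checkpointing_func is not None

    def gradient_checkpointing_enable(self):
        for module in self.modules():
            self._set_gradient_checkpointing(
                module, gradient_checkpointing_func=torch.utils.checkpoint.checkpoint
            )


model = ToyPreTrainedModel()
model.gradient_checkpointing_enable()
assert model.model.gradient_checkpointing
assert model.model.vision_encoder.gradient_checkpointing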
src/transformers/models/dinov2/modeling_dinov2.py | 2 +- src/transformers/models/distilbert/modeling_distilbert.py | 2 +- src/transformers/models/donut/modeling_donut_swin.py | 2 +- src/transformers/models/dpt/modeling_dpt.py | 2 +- src/transformers/models/electra/modeling_electra.py | 2 +- src/transformers/models/ernie/modeling_ernie.py | 2 +- src/transformers/models/esm/modeling_esm.py | 2 +- src/transformers/models/falcon/modeling_falcon.py | 2 +- src/transformers/models/flava/modeling_flava.py | 2 +- src/transformers/models/fnet/modeling_fnet.py | 2 +- src/transformers/models/focalnet/modeling_focalnet.py | 2 +- src/transformers/models/git/modeling_git.py | 4 ++-- src/transformers/models/gpt2/modeling_gpt2.py | 2 +- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 2 +- src/transformers/models/gpt_neo/modeling_gpt_neo.py | 2 +- src/transformers/models/gpt_neox/modeling_gpt_neox.py | 2 +- src/transformers/models/gptj/modeling_gptj.py | 2 +- src/transformers/models/groupvit/modeling_groupvit.py | 2 +- src/transformers/models/hubert/modeling_hubert.py | 6 +++--- src/transformers/models/idefics/vision.py | 2 +- src/transformers/models/imagegpt/modeling_imagegpt.py | 2 +- src/transformers/models/informer/modeling_informer.py | 4 ++-- .../models/instructblip/modeling_instructblip.py | 4 ++-- src/transformers/models/layoutlm/modeling_layoutlm.py | 2 +- src/transformers/models/layoutlmv2/modeling_layoutlmv2.py | 2 +- src/transformers/models/layoutlmv3/modeling_layoutlmv3.py | 2 +- src/transformers/models/led/modeling_led.py | 4 ++-- src/transformers/models/lilt/modeling_lilt.py | 2 +- src/transformers/models/llama/modeling_llama.py | 2 +- src/transformers/models/longformer/modeling_longformer.py | 2 +- src/transformers/models/luke/modeling_luke.py | 2 +- src/transformers/models/m2m_100/modeling_m2m_100.py | 4 ++-- src/transformers/models/marian/modeling_marian.py | 4 ++-- src/transformers/models/markuplm/modeling_markuplm.py | 2 +- src/transformers/models/mask2former/modeling_mask2former.py | 2 +- src/transformers/models/maskformer/modeling_maskformer.py | 2 +- .../models/maskformer/modeling_maskformer_swin.py | 2 +- src/transformers/models/mbart/modeling_mbart.py | 4 ++-- .../models/megatron_bert/modeling_megatron_bert.py | 2 +- src/transformers/models/mistral/modeling_mistral.py | 2 +- src/transformers/models/mobilevit/modeling_mobilevit.py | 2 +- src/transformers/models/mobilevitv2/modeling_mobilevitv2.py | 2 +- src/transformers/models/mpt/modeling_mpt.py | 2 +- src/transformers/models/mra/modeling_mra.py | 2 +- src/transformers/models/mvp/modeling_mvp.py | 4 ++-- src/transformers/models/nezha/modeling_nezha.py | 2 +- src/transformers/models/nllb_moe/modeling_nllb_moe.py | 2 +- .../models/nystromformer/modeling_nystromformer.py | 2 +- src/transformers/models/opt/modeling_opt.py | 2 +- src/transformers/models/owlv2/modeling_owlv2.py | 2 +- src/transformers/models/owlvit/modeling_owlvit.py | 2 +- src/transformers/models/pegasus/modeling_pegasus.py | 4 ++-- src/transformers/models/pegasus_x/modeling_pegasus_x.py | 4 ++-- src/transformers/models/persimmon/modeling_persimmon.py | 2 +- src/transformers/models/pix2struct/modeling_pix2struct.py | 2 +- src/transformers/models/plbart/modeling_plbart.py | 4 ++-- src/transformers/models/prophetnet/modeling_prophetnet.py | 4 ++-- src/transformers/models/qdqbert/modeling_qdqbert.py | 2 +- src/transformers/models/realm/modeling_realm.py | 2 +- src/transformers/models/rembert/modeling_rembert.py | 2 +- 
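# Illustrative sketch, not from the patch: this commit replaces `layer.forward` with
# `layer.__call__` in every checkpointed call. The standalone snippet below shows why the
# distinction matters: only the `__call__` path routes through `nn.Module.__call__`, so
# registered forward hooks still fire when the layer runs under checkpointing.
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(4, 4)
calls = []
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook ran"))

x = torch.randn(2, 4, requires_grad=True)

checkpoint(layer.forward, x, use_reentrant=False)   # bound method, hooks are bypassed
print(calls)  # []

checkpoint(layer.__call__, x, use_reentrant=False)  # goes through nn.Module.__call__
print(calls)  # ['hook ran']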
src/transformers/models/roberta/modeling_roberta.py | 2 +- .../roberta_prelayernorm/modeling_roberta_prelayernorm.py | 2 +- src/transformers/models/roc_bert/modeling_roc_bert.py | 2 +- src/transformers/models/roformer/modeling_roformer.py | 2 +- src/transformers/models/rwkv/modeling_rwkv.py | 2 +- src/transformers/models/sam/modeling_sam.py | 2 +- .../models/seamless_m4t/modeling_seamless_m4t.py | 4 ++-- src/transformers/models/sew/modeling_sew.py | 4 ++-- src/transformers/models/sew_d/modeling_sew_d.py | 4 ++-- .../models/speech_to_text/modeling_speech_to_text.py | 4 ++-- .../models/speech_to_text_2/modeling_speech_to_text_2.py | 2 +- src/transformers/models/speecht5/modeling_speecht5.py | 6 +++--- src/transformers/models/splinter/modeling_splinter.py | 2 +- src/transformers/models/swin/modeling_swin.py | 2 +- src/transformers/models/swin2sr/modeling_swin2sr.py | 2 +- src/transformers/models/swinv2/modeling_swinv2.py | 2 +- .../models/table_transformer/modeling_table_transformer.py | 2 +- src/transformers/models/tapas/modeling_tapas.py | 2 +- .../modeling_time_series_transformer.py | 4 ++-- src/transformers/models/timesformer/modeling_timesformer.py | 2 +- src/transformers/models/trocr/modeling_trocr.py | 2 +- src/transformers/models/tvlt/modeling_tvlt.py | 4 ++-- src/transformers/models/unispeech/modeling_unispeech.py | 6 +++--- .../models/unispeech_sat/modeling_unispeech_sat.py | 6 +++--- src/transformers/models/videomae/modeling_videomae.py | 4 ++-- src/transformers/models/vilt/modeling_vilt.py | 2 +- src/transformers/models/visual_bert/modeling_visual_bert.py | 2 +- src/transformers/models/vit/modeling_vit.py | 2 +- src/transformers/models/vit_hybrid/modeling_vit_hybrid.py | 2 +- src/transformers/models/vit_mae/modeling_vit_mae.py | 4 ++-- src/transformers/models/vit_msn/modeling_vit_msn.py | 2 +- src/transformers/models/vitdet/modeling_vitdet.py | 2 +- src/transformers/models/vits/modeling_vits.py | 2 +- src/transformers/models/vivit/modeling_vivit.py | 2 +- src/transformers/models/wav2vec2/modeling_wav2vec2.py | 6 +++--- .../wav2vec2_conformer/modeling_wav2vec2_conformer.py | 4 ++-- src/transformers/models/wavlm/modeling_wavlm.py | 6 +++--- src/transformers/models/whisper/modeling_whisper.py | 4 ++-- src/transformers/models/x_clip/modeling_x_clip.py | 4 ++-- src/transformers/models/xglm/modeling_xglm.py | 2 +- .../models/xlm_prophetnet/modeling_xlm_prophetnet.py | 4 ++-- src/transformers/models/xlm_roberta/modeling_xlm_roberta.py | 2 +- .../models/xlm_roberta_xl/modeling_xlm_roberta_xl.py | 2 +- src/transformers/models/xmod/modeling_xmod.py | 2 +- src/transformers/models/yolos/modeling_yolos.py | 2 +- src/transformers/models/yoso/modeling_yoso.py | 2 +- .../modeling_{{cookiecutter.lowercase_modelname}}.py | 6 +++--- 149 files changed, 197 insertions(+), 197 deletions(-) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 7b141b5f65a367..58dc2a89200930 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -1096,7 +1096,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 71e650adba1bff..e6229165aace86 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ 
b/src/transformers/models/altclip/modeling_altclip.py @@ -647,7 +647,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, @@ -956,7 +956,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index 1c79f3cfd78b21..a1f85e2a09eb12 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -337,7 +337,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 29073c3d57dd3e..40e3002310852b 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -1209,7 +1209,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1421,7 +1421,7 @@ def forward( ) use_cache = False layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 11c53ccbdb2192..2708b00d05c49c 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -639,7 +639,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, None, attention_mask, diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 390af1a825a753..73eca72e5d1288 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -856,7 +856,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1106,7 +1106,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 860de96323be6a..3ba3d4911b0fb5 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -511,7 +511,7 @@ def 
forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index b251c9c9b55916..91380e13a05522 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -594,7 +594,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 97fb89e95413d6..123cb2212e19c6 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -402,7 +402,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 890eb8c6875f3e..0ba2119e684492 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1618,7 +1618,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 03ef911970ad87..98ff51032bad5e 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1945,7 +1945,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -2285,7 +2285,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 7dc72aa6368ecd..2bbdbed348a196 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -592,7 +592,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 35879ac1500a97..51a947af0a8324 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -779,7 +779,7 @@ def forward( else: if self.gradient_checkpointing 
and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1028,7 +1028,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 59ba6b9dd874b8..88a9b52de90966 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -777,7 +777,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1025,7 +1025,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 927c33f9927c08..efd986299c292b 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -624,7 +624,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, output_attentions, diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index a9decd052d375d..e0aa4e17f146fe 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -423,7 +423,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 735b81bc4229e1..2f7f00b3dd5970 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -479,7 +479,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, output_attentions, @@ -944,7 +944,7 @@ def forward( ) use_cache = False layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 83998421e131b0..583367c9ab5571 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -763,7 +763,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, 
alibi, causal_mask, diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index ea4c3cc285badd..0f272a21e21da7 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -805,7 +805,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index 60e753c95f8de3..c10f8350567611 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -652,7 +652,7 @@ def forward( ) use_cache = False layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, bbox_pos_emb, attention_mask, diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index d5d9f0ae488f20..2e0a6c12fe6463 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -525,7 +525,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 198e3376731adc..adc875910320eb 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -796,7 +796,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index a010d82fd9de5e..ef1c265723b6f0 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -911,7 +911,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, @@ -1015,7 +1015,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, output_attentions, ) diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 7c6c9618c4536e..025b59ae4b9743 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -940,7 +940,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( @@ -1589,7 +1589,7 @@ def forward( if self.gradient_checkpointing and 
self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 9e179753157baf..56f24c157f831c 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -641,7 +641,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 0bded11f9bc1da..7a0e5292698393 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -650,7 +650,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 0a01e05044e47e..340719e1fb7880 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -542,7 +542,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, None, attention_mask, diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index c887b170c9cdcd..01dbf8ecd59cc7 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -1520,7 +1520,7 @@ def forward( query_sine_embed = query_sine_embed_before_transformation * pos_transformation if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, combined_attention_mask, object_queries, diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index e240830214253b..da577a58961430 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -634,7 +634,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 5a2491571efaca..a99b6f3a6dc384 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -294,7 +294,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -587,7 +587,7 @@ 
def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index ba5c6b97a965d9..507c2fc464d8b7 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -511,7 +511,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index 6c5c39e4957c52..2742d5ffc37bbe 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -523,7 +523,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index a7816bae558bec..65ec497cecd8c0 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -458,7 +458,7 @@ def forward( if self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, next_kv, attention_mask, query_states, diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index e536d376c59107..2245ac549ada94 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -502,7 +502,7 @@ def forward( if self.gradient_checkpointing and self.training: output_states = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, next_kv, attention_mask, query_states, diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 8146436cdc5186..19c2731a50a745 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -633,7 +633,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, None, attention_mask, diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index b33ba3a5fa2a23..220fcf0d066003 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1385,7 +1385,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, encoder_hidden_states, encoder_attention_mask, diff --git 
a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 4cd8785ce535e0..6e97e932b533a7 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -358,7 +358,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index 779b409470d920..9e7a73c5880be7 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -618,7 +618,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 80f27e4d666cc9..fb1cc7f0fb84d5 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -667,7 +667,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, position_ids, diff --git a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py index 8081a96430bcea..c9f31c714446e6 100644 --- a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py @@ -552,7 +552,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, layer_past, use_cache, diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index 9e0954736963e5..a6f979eaeea6e1 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -1277,7 +1277,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, encoder_hidden_states, encoder_attention_mask, diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 7781298b0137a9..1c09e3e3d7b213 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -1255,7 +1255,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 1fd39703bce305..1440b6d615fb4a 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ 
b/src/transformers/models/dinov2/modeling_dinov2.py @@ -448,7 +448,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index db48ac56fee310..3768dd6e91ca7e 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -359,7 +359,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_state, attn_mask, head_mask[i], diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index a789b7ef57ba46..76d525717f8cf3 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -750,7 +750,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 513892740ed7c6..2621fa338015a8 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -529,7 +529,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index eee30624719ecf..fde5632c09c3f2 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -572,7 +572,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index d88563e778c790..330cb5033160e5 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -507,7 +507,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 21e480c8212b93..86bd20a46480d0 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -606,7 +606,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 
5c011ddf3d5c60..642e60a72f91df 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1280,7 +1280,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, alibi, attention_mask, diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 61431463215194..1fbf49f9e127e6 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -664,7 +664,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index f9ec022845f065..b84761536bac46 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -292,7 +292,7 @@ def forward(self, hidden_states, output_hidden_states=False, return_dict=True): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func(layer_module.forward, hidden_states) + layer_outputs = self.gradient_checkpointing_func(layer_module.__call__, hidden_states) else: layer_outputs = layer_module(hidden_states) diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index 5ff1c99b94f3da..87ec98169626b8 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -587,7 +587,7 @@ def forward( for i, stage_module in enumerate(self.stages): if self.gradient_checkpointing and self.training: stage_outputs = self.gradient_checkpointing_func( - stage_module.forward, + stage_module.__call__, hidden_states, input_dimensions, ) diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 0e44931eb99ea1..293b9c789d56b9 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -453,7 +453,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, @@ -875,7 +875,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index dc28ed3640f472..24826a76bc0446 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -879,7 +879,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, None, attention_mask, diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 7d4e77a4674f6a..37c51b40c9a78d 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ 
b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -652,7 +652,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, None, attention_mask, diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 6ede0829cd03a5..ed1e62bf175fb1 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -606,7 +606,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, None, attention_mask, diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 860552cde48527..cf0aa0645ae07f 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -643,7 +643,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, position_ids, diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index c0302b6c21a044..65f805b71716aa 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -669,7 +669,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, None, attention_mask, diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index 332b14d9961cb7..a9de67143846b7 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -1032,7 +1032,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index e5b1b1742e74ba..732e6be2f8ddd1 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -347,7 +347,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -725,7 +725,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, @@ -809,7 +809,7 @@ def forward( # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index cb604909e1927c..24dc3e9396aa79 
100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -402,7 +402,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 187f39248fbc8e..a365731ed53d0e 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -818,7 +818,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, None, attention_mask, diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 423de7d819769a..53518760cc003b 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -1217,7 +1217,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1434,7 +1434,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 3cc44efbe3618a..d4cb7a1fa00bcf 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -468,7 +468,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, output_attentions, @@ -939,7 +939,7 @@ def forward( ) use_cache = False layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index dc094bd8ba0bff..ce6d4302bccc19 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -488,7 +488,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index 03900bff907c91..8f6260fdda4931 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -440,7 +440,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git 
a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 9afc855417fabd..e387707e52da4f 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -658,7 +658,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 3d4e3c26188c57..61bbd4156b4603 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1878,7 +1878,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, @@ -2139,7 +2139,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 2c7085aa822821..4fd7a85affd76c 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -515,7 +515,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layout_inputs, attention_mask, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8f55982565ce98..279884dc164f04 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1015,7 +1015,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, position_ids, diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 6ca8f61cfa4ca1..b4f20b4525585e 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1305,7 +1305,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index 143932f924bf6a..3b5f4d0bf71dc0 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -789,7 +789,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, word_hidden_states, entity_hidden_states, attention_mask, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index b9b672ca28291c..4ebe11f3f3b3bc 
100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -822,7 +822,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1062,7 +1062,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 81a4d7b6f6b527..e2e09b564b0e48 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -790,7 +790,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1033,7 +1033,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index fc15c86e7a9460..80498efb3cadd2 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -649,7 +649,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 7d00b6b6d87127..86eccc478753f3 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -1865,7 +1865,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index a941c0508a94f0..7df8b60792a054 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -849,7 +849,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index dd6c45de8a56b0..89c6a0c0e0b4c1 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -689,7 +689,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_hidden_states, output_dimensions, layer_all_hidden_states = self.gradient_checkpointing_func( - 
layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 341260efe45cbb..7c4c9bdf959801 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -830,7 +830,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1082,7 +1082,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index a2e2a39ec966ab..c23666f10b725f 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -552,7 +552,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index a95215f7641a6c..36b5a4b66bb5a8 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1021,7 +1021,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, position_ids, diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index 1e8a8afa07ddd0..c664c02a883ba0 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -627,7 +627,7 @@ def forward( for i, layer_module in enumerate(self.layer): if self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, ) else: diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index c857915a8cca99..b88925f41b83c6 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -583,7 +583,7 @@ def forward( for i, layer_module in enumerate(self.layer): if self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, ) else: diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 897a90ce0486a1..ede306e71b867e 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -525,7 +525,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, alibi, causal_mask, diff --git a/src/transformers/models/mra/modeling_mra.py 
b/src/transformers/models/mra/modeling_mra.py index 1da9da2af9159f..f6cb65889a371c 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -767,7 +767,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, ) diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index d8622fca958264..122b49287872ae 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -951,7 +951,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1223,7 +1223,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/nezha/modeling_nezha.py b/src/transformers/models/nezha/modeling_nezha.py index a8ad52d2698831..cd43688e3f741a 100644 --- a/src/transformers/models/nezha/modeling_nezha.py +++ b/src/transformers/models/nezha/modeling_nezha.py @@ -578,7 +578,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 51bbd56d2b58cc..cbed1e1b153011 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -1155,7 +1155,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 9a023cbc91ef36..9b2052eb6ca4ae 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -371,7 +371,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, output_attentions, diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 5782d796566a04..9925e7b4a46b4a 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -693,7 +693,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, causal_attention_mask, head_mask[idx] if head_mask is not None else None, diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index 351a1a77d59a7e..a1491d15ea5527 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ 
b/src/transformers/models/owlv2/modeling_owlv2.py @@ -766,7 +766,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index 63e1570a110697..68037d13950ed6 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -755,7 +755,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 5fc671f25f4684..058ecd1775a9bf 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -805,7 +805,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1083,7 +1083,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index f35bef20969d30..6eaddf642a8b6b 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1073,7 +1073,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, global_hidden_states, attention_mask, @@ -1327,7 +1327,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index fda50ca47690ec..8043fc8699a655 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -670,7 +670,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, position_ids, diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 9b4444c56ee294..cfc2b137c579cf 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -344,7 +344,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git 
a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index cdd73be66d7ae3..1e047fd372676e 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -809,7 +809,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1060,7 +1060,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index eb1b319fb19a4a..e4c28659cb489b 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1331,7 +1331,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, extended_attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1565,7 +1565,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, extended_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index 69be03b93bded0..0a2546a9b64e86 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -582,7 +582,7 @@ def forward( ) use_cache = False layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index a63e3a9e9bce6f..86b37b21560ba4 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -587,7 +587,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 6471653da7bf74..e5e662a9b5564d 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -544,7 +544,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index aedfc5ef807780..32a19c08831777 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -511,7 +511,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, 
hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 1bcdb872451889..78ca20684540bd 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -513,7 +513,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 3627944fab4b95..3a58efa9140c00 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -645,7 +645,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 6773a6f967adb2..3893e27b028f2e 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -579,7 +579,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, sinusoidal_pos, diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index d7c7df9a839002..275233321372a2 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -678,7 +678,7 @@ def forward( for idx, block in enumerate(self.blocks): if self.gradient_checkpointing and self.training: hidden_states, state, attentions = self.gradient_checkpointing_func( - block.forward, hidden_states, state, use_cache, output_attentions + block.__call__, hidden_states, state, use_cache, output_attentions ) else: hidden_states, state, attentions = block( diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index d384747af33653..1bd6fcdc2a8f14 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -1043,7 +1043,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, ) else: diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index a930d60ec9da4c..ea79c734188391 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -893,7 +893,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, relative_position_embeddings, @@ -2119,7 +2119,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + 
decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 883fab34fce208..36416c168c3600 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -361,7 +361,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -667,7 +667,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 2dc2231e607335..39c9641b9489fe 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -454,7 +454,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -1128,7 +1128,7 @@ def forward( if self.gradient_checkpointing and self.training: output_states = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, next_kv, attention_mask, query_states, diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 030358ff033a1d..73a02fe66df7ba 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -819,7 +819,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1061,7 +1061,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index 486dda2f46b4c7..acee2b15a44f2c 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -671,7 +671,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 40d30f366a2066..b8fea796647b70 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -521,7 +521,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = 
self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -1381,7 +1381,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1701,7 +1701,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index d766f435f15010..1bdf8f3f5f9194 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -460,7 +460,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 25432478abeaf1..c2f15dbbf27394 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -826,7 +826,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index d7b248b1135990..47ce01d1691668 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -747,7 +747,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - stage_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions + stage_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = stage_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index c00ae39e0bec2d..6daad938a623ab 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -907,7 +907,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index fc9f001ff0603c..e1da557b0017cc 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -1151,7 +1151,7 @@ def forward( if self.gradient_checkpointing and 
self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index ae22bbd8449d2c..de05d77ec94358 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -647,7 +647,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 349bc5d48adfe2..1fa6a963f58f50 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -948,7 +948,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1159,7 +1159,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index ccc65287cdc20a..044705c35e5410 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -440,7 +440,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, output_attentions, ) diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 9b7fab8e2f3d4d..ada8638a03b608 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -703,7 +703,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/tvlt/modeling_tvlt.py b/src/transformers/models/tvlt/modeling_tvlt.py index 086cf66fd40dda..a37265f37c7ad8 100644 --- a/src/transformers/models/tvlt/modeling_tvlt.py +++ b/src/transformers/models/tvlt/modeling_tvlt.py @@ -561,7 +561,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, @@ -873,7 +873,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, None, output_attentions, diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index bcfc4069c8a35f..db14d5bca51f5c 100755 --- 
a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -385,7 +385,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -761,7 +761,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, @@ -845,7 +845,7 @@ def forward( # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 778dbfad18a9e0..8a9a63804b56a4 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -399,7 +399,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -775,7 +775,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, @@ -859,7 +859,7 @@ def forward( # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 84ff258c58b812..277280954fd6f1 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -435,7 +435,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, @@ -722,7 +722,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, None, output_attentions, diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index a93dc99903e1bb..482bd08359bd4a 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -532,7 +532,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py 
b/src/transformers/models/visual_bert/modeling_visual_bert.py index 36a1292fc9fdb0..425a125a0b89dd 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -419,7 +419,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index b06ab62113a745..67dbddf8766a41 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -398,7 +398,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py index 7b54e6c1535b3f..959522843f7a67 100644 --- a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py @@ -416,7 +416,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 5fa10ca9d1376c..e156fdc3292c4b 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -537,7 +537,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, @@ -789,7 +789,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, None, output_attentions, diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 91e13c7b6adc9d..b727c331cfb4d7 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -388,7 +388,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py index 8e20f17e070920..9bb3991fabf186 100644 --- a/src/transformers/models/vitdet/modeling_vitdet.py +++ b/src/transformers/models/vitdet/modeling_vitdet.py @@ -566,7 +566,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index 49b8e1a6a40a18..b621bde35e61da 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -1168,7 +1168,7 @@ def forward( # under deepspeed zero3 all 
gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, padding_mask, attention_mask, diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index b4ed99bd9e98a1..50cb82fb4e188f 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -339,7 +339,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index ec38d6a11570ff..9f48e529627e8e 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -452,7 +452,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -797,7 +797,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, @@ -880,7 +880,7 @@ def forward( # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 5d723592556843..5fba773ee0cb4f 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -519,7 +519,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -905,7 +905,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, relative_position_embeddings, diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index ef76b43330890b..55b19e4c414341 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -355,7 +355,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -707,7 +707,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + 
layer.__call__, hidden_states, attention_mask, position_bias, @@ -792,7 +792,7 @@ def forward( # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, position_bias, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index c868abe44c0edc..d6d0302727cb09 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -944,7 +944,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, None, (head_mask[idx] if head_mask is not None else None), @@ -1170,7 +1170,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 46ad1fb719e7cb..6c9cc02db9c83f 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -705,7 +705,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, @@ -946,7 +946,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 075e0c3159704c..1880a7832193a8 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -676,7 +676,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py index bc86fd6ff5c8fb..9a9f02b74a65a4 100644 --- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -1351,7 +1351,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, extended_attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1588,7 +1588,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, extended_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 1bc22ca1004580..da99b2806fb6f8 100644 
--- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -512,7 +512,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 3477d709ae0e33..49f7c07517211c 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -500,7 +500,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 4ca4adeec995bd..5f7b42f266fb1e 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -574,7 +574,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, lang_ids, attention_mask, diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index a378e96f9909c2..f6cbaecd014e61 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -493,7 +493,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index b0cbd589b293b8..8db66d22106160 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -562,7 +562,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, output_attentions, diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index e2e36ac3682308..0b5af845c9aaa3 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -545,7 +545,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, @@ -2311,7 +2311,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -2544,7 +2544,7 @@ def forward( if self.gradient_checkpointing and 
self.training:
                    layer_outputs = self.gradient_checkpointing_func(
-                        decoder_layer.forward,
+                        decoder_layer.__call__,
                         hidden_states,
                         attention_mask,
                         encoder_hidden_states,

From cec602a669f973e7d8cd790986ac4071a566ddf5 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Wed, 25 Oct 2023 09:43:31 +0000
Subject: [PATCH 11/12] add comment

---
 src/transformers/modeling_utils.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index aabb8e34cbb092..73255b021f5f2a 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1856,6 +1856,9 @@ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
         Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
         activations".
 
+        We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks
+        of the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+
         Args:
             gradient_checkpointing_kwargs (dict, *optional*):
                 Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.

From ded70d40d071f574d060f913431100c6f7be5fe4 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Wed, 25 Oct 2023 09:52:45 +0000
Subject: [PATCH 12/12] quality

---
 src/transformers/modeling_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 73255b021f5f2a..47e9cb2f23e0f1 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1856,8 +1856,8 @@ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
         Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
         activations".
 
-        We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks
-        of the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+        We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
+        the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
 
         Args:
             gradient_checkpointing_kwargs (dict, *optional*):
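The change repeated across every hunk above is exactly what the comment added in PATCH 11/12 documents: `nn.Module.__call__` runs the module's registered hooks around `forward`, while calling `forward` directly skips them, so hook-based integrations (for example DeepSpeed ZeRO-3 parameter gathering, which the recurring "under deepspeed zero3 all gpus must run in sync" comments refer to) would silently not run inside the checkpointed call. Below is a minimal sketch of that difference, assuming plain PyTorch; the `nn.Linear` layer, the hook, and the variable names are illustrative only and are not part of the patch.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(4, 4)
hook_calls = []
# Forward hooks are only triggered through nn.Module.__call__, never by forward() itself.
layer.register_forward_hook(lambda module, inputs, output: hook_calls.append(1))

x = torch.randn(2, 4, requires_grad=True)

layer.forward(x)                # bypasses __call__ -> the hook does not run
layer(x)                        # goes through __call__ -> the hook runs
print(len(hook_calls))          # 1

# The same asymmetry applies inside torch.utils.checkpoint, which is what
# gradient_checkpointing_func forwards to according to the docstring in PATCH 11/12.
hook_calls.clear()
checkpoint(layer.forward, x, use_reentrant=False)    # old behaviour: the hook is skipped
checkpoint(layer.__call__, x, use_reentrant=False)   # new behaviour: the hook runs
print(len(hook_calls))          # 1

Enabling checkpointing stays a one-liner on the user side; per the signature and docstring shown above, a call such as `model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})` should forward the extra argument straight to `torch.utils.checkpoint.checkpoint`.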