Gradient checkpointing throws use_reentrant warning on PyTorch 2.1 #28536

Closed
2 of 4 tasks
rosario-purple opened this issue Jan 16, 2024 · 15 comments · Fixed by #28538 or #33208
Comments

@rosario-purple

System Info

  • transformers version: 4.36.2
  • Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
  • Python version: 3.10.13
  • Huggingface_hub version: 0.19.4
  • Safetensors version: 0.4.0
  • Accelerate version: 0.25.0
  • Accelerate config: - compute_environment: LOCAL_MACHINE
    - distributed_type: DEEPSPEED
    - mixed_precision: bf16
    - use_cpu: False
    - debug: False
    - num_processes: 8
    - machine_rank: 0
    - num_machines: 1
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - deepspeed_config: {'gradient_accumulation_steps': 1, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': True, 'zero3_save_16bit_model': False, 'zero_stage': 3}
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []
  • PyTorch version (GPU?): 2.1.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.7.5 (cpu)
  • Jax version: 0.4.21
  • JaxLib version: 0.4.21
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Training any text model with gradient checkpointing enabled on PyTorch 2.1 and higher produces this warning:

/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: Warning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.

This can be resolved by manually monkey-patching the model code to pass use_reentrant=True, e.g. like so:

                hidden_states, self_attns, decoder_cache = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                    is_padded_inputs,
                    use_reentrant=True,
                )

This is caused by an upstream change in PyTorch:

https://medium.com/pytorch/how-activation-checkpointing-enables-scaling-up-training-deep-learning-models-7a93ae01ff2d
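For context, the warning comes from torch.utils.checkpoint itself whenever use_reentrant is not passed. A minimal sketch (not from the issue; the layer and tensor are illustrative) showing both the warning and how passing the flag explicitly silences it:

    # Minimal sketch, assuming PyTorch >= 2.1; the layer and input are illustrative.
    import torch
    import torch.utils.checkpoint as cp

    layer = torch.nn.Linear(16, 16)
    x = torch.randn(4, 16, requires_grad=True)

    # Default call: PyTorch 2.1+ emits the use_reentrant warning here.
    out = cp.checkpoint(layer, x)
    out.sum().backward()

    # Passing the flag explicitly (True or False) silences the warning;
    # use_reentrant=False is the variant PyTorch recommends going forward.
    out = cp.checkpoint(layer, x, use_reentrant=False)
    out.sum().backward()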

Expected behavior

No warning should be emitted.

@ArthurZucker
Collaborator

Thanks for raising this! Given that we had #27020, this should be fairly easy to fix! cc @younesbelkada

@rosario-purple
Author

@ArthurZucker is this still outstanding?

@ArthurZucker
Collaborator

Will merge the PR today

@lucasjinreal

Which version fixed this? I'm using 3.47.2 and still get this error.

@huangganggui

huangganggui commented Apr 11, 2024

4.39.3 still gets this warning.

@huangganggui

> 4.39.3 still gets this warning.

In my case, model.gradient_checkpointing_enable() fixed it. Maybe you can try that, @lucasjinreal.

@ankush13r
Contributor

I'm using transformers==4.43.3 and still get this warning when using the Trainer API with gradient_checkpointing=True.

@BigDataMLexplorer

BigDataMLexplorer commented Aug 11, 2024

> I'm using transformers==4.43.3 and still get this warning when using the Trainer API with gradient_checkpointing=True.

Me too. Try calling model.gradient_checkpointing_enable() and do not specify gradient_checkpointing=True in the Hugging Face Trainer API. That solved my problem.
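A minimal sketch of that workaround (the checkpoint name and train_dataset are placeholders, not from the thread):

    # Sketch of the workaround above: enable gradient checkpointing on the model
    # directly and leave the Trainer's gradient_checkpointing flag unset.
    # "my-llama-checkpoint" and `train_dataset` are placeholders.
    from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

    model = AutoModelForCausalLM.from_pretrained("my-llama-checkpoint")
    model.gradient_checkpointing_enable()  # instead of gradient_checkpointing=True below

    args = TrainingArguments(
        output_dir="out",
        per_device_train_batch_size=1,
        # gradient_checkpointing=True intentionally omitted here
    )
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
    trainer.train()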

@ArthurZucker
Collaborator

Could you all share which model you are using? 🤗

@BigDataMLexplorer

Hi, I use Llama 3 8B.

@ankush13r
Contributor

Hello, I'm using Llama-2-7b and Mistral-7B-v0.3. Both give the same warning.

@ArthurZucker
Collaborator

Are you using a recent version of transformers? By default we do pass this flag:

def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):

so you can call something like model.gradient_checkpointing_enable({"use_reentrant": False}). But by default we already pass this flag when gradient checkpointing is used.
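For anyone landing here, the call suggested above looks like this (a sketch; model is any PreTrainedModel that supports gradient checkpointing):

    # Sketch: pass the flag through gradient_checkpointing_kwargs so the model's
    # internal torch.utils.checkpoint calls receive use_reentrant explicitly.
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": False}
    )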

@ankush13r
Contributor

Thanks, that will solve my problem, since I'm using the Trainer and can pass this argument.
But I think the issue is here:
The Trainer assigns the value gradient_checkpointing_kwargs = {} (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L2122).
So when it reaches the check in gradient_checkpointing_enable

if gradient_checkpointing_kwargs is None:

the argument is not None, it is an empty dict, so the default kwargs are never applied.
The problem would be solved by removing this if: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L2121
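To make the mechanics concrete, a simplified illustration of the behaviour described above (not the actual transformers source), followed by a Trainer-side workaround that should work on versions where TrainingArguments exposes gradient_checkpointing_kwargs:

    # Simplified illustration of the bug described above (not the real source):
    # the Trainer hands over an empty dict, which is not None, so the default
    # kwargs inside gradient_checkpointing_enable are never applied and
    # use_reentrant is never set explicitly.
    gradient_checkpointing_kwargs = {}          # what the Trainer passes
    if gradient_checkpointing_kwargs is None:   # check inside gradient_checkpointing_enable
        gradient_checkpointing_kwargs = {"use_reentrant": True}  # default, skipped here

    # Workaround until that is fixed: supply the kwargs explicitly from the
    # Trainer side (assuming a transformers version whose TrainingArguments
    # accepts gradient_checkpointing_kwargs).
    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="out",
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )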

@ArthurZucker
Collaborator

Yep good catch! Do you want to open a PR for this? 🤗

@ankush13r
Contributor

Of course, it would be a pleasure to collaborate.
