Reward modelling example throws RuntimeError: Expected to mark a variable ready only once. when gradient_checkpointing=True #831

Closed
lewtun opened this issue Oct 4, 2023 · 5 comments

Comments

@lewtun
Member

lewtun commented Oct 4, 2023

Running the reward_trainer.py example on multiple GPUs with gradient checkpointing throws the following error:

Traceback (most recent call last):
  File "/fsx/lewis/git/trl/examples/scripts/reward_trainer.py", line 169, in <module>
    trainer.train()
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 1591, in train
    return inner_training_loop(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 1892, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 2787, in training_step
    self.accelerator.backward(loss)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/accelerate/accelerator.py", line 1985, in backward
    loss.backward(**kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 157, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 386 has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration. You can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print parameter names for further debugging.

To reproduce, run:

ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml examples/scripts/reward_trainer.py 
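The second cause listed in the error message (the same parameters reused across multiple reentrant backward passes) can be illustrated with a toy model of DDP's bookkeeping. This is a stdlib-only sketch, not PyTorch code. `ToyReducer` and `backward_step` are hypothetical names, and the sketch assumes the reward example runs the same checkpointed layers more than once per step (for instance, once for the chosen and once for the rejected sequence). With reentrant checkpointing, each checkpointed call runs its own nested backward pass (the nested `torch.autograd.backward` call inside checkpoint.py in the traceback), so a shared parameter's ready hook can fire once per call. With non-reentrant checkpointing, a single autograd pass sums the contributions first, so the hook fires once.

```python
from __future__ import annotations


class ToyReducer:
    """Stand-in for DDP's reducer: each parameter may be marked ready once per step."""

    def __init__(self) -> None:
        self.ready: set[str] = set()

    def mark_ready(self, param: str) -> None:
        if param in self.ready:
            raise RuntimeError(f"Expected to mark a variable ready only once: {param!r}")
        self.ready.add(param)


def backward_step(params: list[str], num_forward_calls: int, reentrant: bool) -> ToyReducer:
    """Simulate one backward pass over layers that ran `num_forward_calls` times."""
    reducer = ToyReducer()
    if reentrant:
        # Reentrant checkpointing: every checkpointed call runs its own nested
        # backward, so each parameter's ready hook fires once per forward call.
        for _ in range(num_forward_calls):
            for param in params:
                reducer.mark_ready(param)
    else:
        # Non-reentrant checkpointing: one autograd pass sums all contributions
        # before accumulating, so each parameter's hook fires exactly once
        # (whatever num_forward_calls is).
        for param in params:
            reducer.mark_ready(param)
    return reducer
```

Calling `backward_step(["w"], num_forward_calls=2, reentrant=True)` raises the same kind of RuntimeError, while `reentrant=False` marks each parameter exactly once.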

github-actions bot commented Nov 3, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@younesbelkada
Contributor

This is now fixed on the main branches of transformers, peft, and trl; you just need to pass gradient_checkpointing_kwargs={"use_reentrant": False}.

@wei-ann-Github

Hi, where do we pass this argument? I am facing this issue when using SFTTrainer.

@younesbelkada
Contributor

Hi @wei-ann-Github
Pass that argument to TrainingArguments. Note, however, that you need the latest transformers: pip install -U transformers
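A minimal sketch of the two settings involved, assuming a transformers release that has the `gradient_checkpointing_kwargs` field on `TrainingArguments`. The helper name `checkpointing_training_kwargs` is hypothetical; it only builds the keyword arguments, so it can be checked without transformers installed.

```python
from __future__ import annotations

from typing import Any


def checkpointing_training_kwargs(use_reentrant: bool = False) -> dict[str, Any]:
    """Hypothetical helper: keyword arguments to splat into TrainingArguments(...)."""
    return {
        # Turn on activation (gradient) checkpointing in the Trainer.
        "gradient_checkpointing": True,
        # Forwarded to the checkpointing call; use_reentrant=False is the
        # setting suggested in this thread for the DDP error above.
        "gradient_checkpointing_kwargs": {"use_reentrant": use_reentrant},
    }
```

These would then be splatted into the arguments object, e.g. `TrainingArguments(output_dir="reward_model", **checkpointing_training_kwargs())` (the `output_dir` value is illustrative), and that object passed to the trainer as `args=`.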

@cxjtju

cxjtju commented Apr 28, 2024

Which transformers version is required? With transformers == 4.32.0, I encountered dataclasses.FrozenInstanceError: cannot assign to field gradient_checkpointing_kwargs
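One way to catch this up front is to compare the installed transformers version against a minimum before building the arguments. The sketch below assumes `gradient_checkpointing_kwargs` first appeared in transformers 4.35.0; that cutoff is an assumption, so confirm it against the transformers release notes. The function names are hypothetical and only the standard library is used.

```python
from __future__ import annotations

import re
from importlib.metadata import PackageNotFoundError, version

# Assumed minimum release with TrainingArguments.gradient_checkpointing_kwargs;
# verify against the transformers release notes.
MIN_TRANSFORMERS = (4, 35, 0)


def release_tuple(ver: str) -> tuple[int, ...]:
    """Parse the leading numeric segment, e.g. '4.36.0.dev0' -> (4, 36, 0).

    Pre-release/dev suffixes are ignored and short versions are zero-padded.
    """
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unrecognised version string: {ver!r}")
    parts = [int(p) for p in match.group(0).split(".")]
    parts += [0] * (3 - len(parts))  # "4.35" -> [4, 35, 0]
    return tuple(parts)


def supports_checkpointing_kwargs(ver: str) -> bool:
    """True if `ver` is at or above the assumed minimum release."""
    return release_tuple(ver) >= MIN_TRANSFORMERS


def installed_transformers_supports_it() -> bool:
    """Check the installed transformers package, if there is one."""
    try:
        return supports_checkpointing_kwargs(version("transformers"))
    except PackageNotFoundError:
        return False
```

For the version in the question, `supports_checkpointing_kwargs("4.32.0")` returns `False`.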
