Add support for gradient checkpointing in BERT #4659
Conversation
Codecov Report
```diff
@@            Coverage Diff             @@
##           master    #4659      +/-   ##
==========================================
- Coverage   78.40%   78.06%   -0.35%
==========================================
  Files         138      138
  Lines       23757    23766       +9
==========================================
- Hits        18627    18552      -75
- Misses       5130     5214      +84
```
Continue to review full report at Codecov.
I think this is a great addition, from which all models could benefit. Let's see what the rest of the team thinks, and we'll look at upstreaming it into transformers.PreTrainedModel if everyone's on board.
Thanks, @LysandreJik. It would be great to make this model-agnostic.
I was thinking of having the implementation be model-agnostic as well. I haven't really thought out the best way, but a possible way to achieve it would be with a decorator; for example, in PreTrainedModel:

```python
import functools
import torch

# Proposed static method on PreTrainedModel
@staticmethod
def gradient_checkpointing(layer):
    @functools.wraps(layer)
    def wrapper(*args):
        layer_instance = args[0]
        # Remove the wrapper to prevent infinite recursion on the wrapper
        layer_instance.forward = functools.partial(
            layer_instance.forward.__wrapped__, layer_instance
        )
        # Only checkpoint when the config flag is set
        if args[0].config.gradient_checkpointing:
            return torch.utils.checkpoint.checkpoint(layer_instance, *args[1:])
        else:
            return layer(*args)
    return wrapper
```

Then we can very simply add that decorator on the layers where we want the checkpoint:

```python
class BertLayer(nn.Module):
    ...

    @PreTrainedModel.gradient_checkpointing
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        ...
```

This would require that these layers have access to the configuration so that they're aware of gradient checkpointing or not. Pretty convenient, but pretty different from our coding style as well. cc @thomwolf
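For readers less familiar with the primitive the decorator wraps, here is a small self-contained sketch of torch.utils.checkpoint.checkpoint on a toy network, independent of any transformers code:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# A toy two-block network; each checkpointed block's activations are
# recomputed during backward instead of being stored during forward.
block1 = nn.Sequential(nn.Linear(128, 128), nn.ReLU())
block2 = nn.Sequential(nn.Linear(128, 128), nn.ReLU())

x = torch.randn(16, 128, requires_grad=True)  # input requires grad so gradients flow through the checkpoint
h = checkpoint(block1, x)   # activations inside block1 are not kept
y = checkpoint(block2, h)   # same for block2
y.sum().backward()          # both blocks are re-run here to rebuild their activations
```

The decorator above makes the same call per BertLayer, with the layer instance itself as the checkpointed function.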
neat
A model-agnostic approach might be best. In my research isolating minimaxir/aitextgen#6 for finetuning larger GPT-2 models, it appeared that checkpointing would have to be implemented at the model level, as this PR does for BERT.
torch.utils.checkpoint.checkpoint works well on a single GPU, but it causes OOM in multi-GPU training with torch.nn.DataParallel.
I like the idea of having a decorator function! Would it be enough to have this wrapper only on all "Model" forward functions, like BertModel.forward?
I haven't tried torch.nn.DataParallel.
I don't think so. Even with the decorator, it is still model-specific; the decorator just makes the syntax easier. You still need to decide where to call it, because too few calls (e.g. only on the top-level model forward) won't reduce peak memory much.
Pinging @julien-c so he can take a look.
Thanks for the advice. But when I tried torch.nn.parallel.DistributedDataParallel I got an error; both find_unused_parameters=True and find_unused_parameters=False fail.
@ibeltagy, after some back and forth offline with @julien-c and @thomwolf, the way you implemented it is preferred, as it's simpler to understand and adheres better to the library's philosophy. I think we can merge this and then add it to all the other models in a following PR. Would you like to take care of that? No worries if not, I can definitely take care of it.
@LysandreJik, glad this will be merged.
I will pass :D
I encounter the same issue with torch 1.5.0 and the latest transformers.
@ewrfcas, @schinger, do you have a small example that reproduces the error? I don't think we can fix the issue itself (it needs a PyTorch fix, pytorch/pytorch#24005), but I think we can work around it by removing the unused parameters mentioned in the error message.
SQuAD example training can reproduce this error: https://github.com/huggingface/transformers/tree/master/examples/question-answering. Run python -m torch.distributed.launch --nproc_per_node=8 ./examples/question-answering/run_squad.py; the error occurs no matter whether find_unused_parameters is true or false.
Thanks. It would be more helpful if you could provide a simpler, smaller example that I can easily run.
You can change --train_file to SQUAD_DIR/dev-v1.1.json; the dev set is small, so the run is quick.
Could you show me an example where gradient checkpointing works successfully with torch.nn.parallel.DistributedDataParallel?
I have trained a base model instead of a large one to delay this problem. Other code is the same as the normal finetuning code.
Here's a small example to replicate the error
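A minimal sketch of the kind of setup that can trigger this class of error (an illustrative reconstruction rather than the exact snippet; it assumes the gradient_checkpointing config flag from this PR, a file named repro.py, and a single-process launch through torch.distributed.launch):

```python
# repro.py -- run with: python -m torch.distributed.launch --nproc_per_node=1 repro.py
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import BertConfig, BertModel

dist.init_process_group(backend="gloo")           # CPU backend keeps the sketch self-contained
config = BertConfig(gradient_checkpointing=True)  # config flag assumed from this PR
model = BertModel(config)
ddp_model = DDP(model, find_unused_parameters=True)

input_ids = torch.randint(0, config.vocab_size, (2, 64))
for _ in range(2):                                # failures typically surface on the second step
    sequence_output = ddp_model(input_ids)[0]
    # The pooler output is never used here, so its parameters receive no
    # gradient, which is what DDP complains about in the errors discussed here.
    sequence_output.sum().backward()
```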
With find_unused_parameters=False you get the unused-parameter error; with find_unused_parameters=True I couldn't replicate the other error. @ewrfcas, do you know how to modify the example above to reproduce it? @schinger, can you try the example above?
I have tried this code. Although it works the first time, the second forward fails. You can try repeating the forward and loss.backward() a few times.
@ewrfcas, I get this error with find_unused_parameters=True.
I have solved this problem by removing the self.pooler layer in BertModel, because it does not forward anything during training. As the error said, all parameters should be used in the loss for DistributedDataParallel with find_unused_parameters=False, and find_unused_parameters=True is incompatible with gradient_checkpointing. Maybe we need this code after the first backward:
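A sketch of the kind of check this could be (not the exact snippet from the comment; model is assumed to be the BERT model being trained):

```python
# After the first loss.backward(), list parameters that never received a
# gradient; these are the "unused" parameters DDP complains about
# (for BertModel trained this way, that is the pooler).
for name, param in model.named_parameters():
    if param.requires_grad and param.grad is None:
        print(f"unused parameter: {name}")
```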
Nice finding, @ewrfcas. @LysandreJik, what is the best way to address this problem? Do we need to fix it, or can we leave it to the user to make sure all the model params are used? Maybe in a separate PR we can find a way to automatically remove unused model params. Also, aside from this issue, is there anything else we need to merge the PR?
Right, I think this should be looked at in a separate PR. Will take a final look and merge this PR tomorrow, and then look at implementing gradient checkpointing for the rest of the models.
This PR adds support for gradient checkpointing in modeling_bert.py to save memory at training time, at the expense of a slower backward pass. This is particularly useful if we want to pretrain a version of BERT for sequences longer than 512. It is also useful for long-document models like Longformer.
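For reference, a minimal usage sketch of what this enables (the gradient_checkpointing attribute on BertConfig is assumed from this PR; this is not a snippet from the diff itself):

```python
import torch
from transformers import BertConfig, BertModel

# Build a BERT model with activation checkpointing enabled via the config
# flag introduced in this PR (attribute name assumed from the discussion above).
config = BertConfig(gradient_checkpointing=True)
model = BertModel(config)

input_ids = torch.randint(0, config.vocab_size, (2, 512))
sequence_output = model(input_ids)[0]
# Backward re-runs each BertLayer's forward, so it is slower, but far fewer
# activations are kept in memory during the forward pass.
sequence_output.sum().backward()
```

The backward pass re-computing each BertLayer's forward is where the slowdown mentioned above comes from.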