Make gradient_checkpointing a training argument #13657
Conversation
Fantastic! Thanks so much for adding this feature and making it independent from tweaking the config object. Loving it!
Left a few small suggestions.
src/transformers/modeling_utils.py
Outdated
        Will activate gradient checkpointing if :obj:`True`, deactivate it if :obj:`False`.
        """
        if not self.supports_gradient_checkpointing and flag:
            logger.warn(f"{self.__class__.__name__} does not support gradient checkpointing so nothing will happen.")
any reason not to assert here instead? The user can then change their setup and proceed without problems.
It's a clear error to activate this option if a model doesn't support it, IMHO.
It's to be consistent with the previous behavior where we did nothing if the user input gradient_checkpointing for a model that did not support it.
I'm not opposed to asserting, but let's see what @LysandreJik and @patrickvonplaten think.
Would also be in favor of raising an error here actually. It's a new function so I think we can add this behavior here
Will switch then!
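For illustration, a minimal sketch of the behavior being agreed on here, using a toy stand-in class (the exact error message and exception type are assumptions, not the merged code):

```python
class ToyModel:
    """Toy stand-in for PreTrainedModel, only to illustrate raising instead of warning."""

    supports_gradient_checkpointing = False

    def gradient_checkpointing_enable(self, flag: bool = True):
        if flag and not self.supports_gradient_checkpointing:
            # Raising (rather than logger.warn) lets the user fix their setup right away.
            raise ValueError(
                f"{self.__class__.__name__} does not support gradient checkpointing."
            )
        # ... the real method would toggle checkpointing on the relevant submodules here ...
```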
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Looks good to me! Thanks for taking care of all the mentions of gradient_checkpointing in the repository, very cool work!
- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by setting
-   ``config.gradient_checkpointing = True``.
+ To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by executing
+   ``model.gradient_checkpointing_enable()``.
How about enable_gradient_checkpointing?
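As a side note, the doc change above translates into a call like the following (the checkpoint name is only an example):

```python
from transformers import LEDForConditionalGeneration

# Any LED checkpoint works here; this one is just an example.
model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")

# Enable gradient checkpointing before fine-tuning so the full 16384-token
# inputs fit in memory (activations are recomputed during the backward pass).
model.gradient_checkpointing_enable()
```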
@@ -932,6 +933,21 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]):

        self.base_model._prune_heads(heads_to_prune)

    def gradient_checkpointing_enable(self, flag: bool = True):
Should there be a disable too?
Ah I didn't see this had a flag! Maybe toggle then? Or set_gradient_checkpointing to follow traditional boolean setter conventions?
@stas00 really wanted the method name to start with gradient_checkpointing to be more easily discoverable.
After some discussion with Lysandre, we decided to try gradient_checkpointing_enable and gradient_checkpointing_disable (no args for each).
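A rough sketch of that agreed-upon pair, again on a toy stand-in class (the internal `_set_gradient_checkpointing` helper is an assumption used only to keep the example short, not the merged implementation):

```python
class ToyPreTrainedModel:
    """Simplified stand-in illustrating the no-argument enable/disable pair."""

    supports_gradient_checkpointing = True

    def _set_gradient_checkpointing(self, value: bool) -> None:
        # In the real library this would flip the flag on the relevant submodules.
        self._gradient_checkpointing = value

    def gradient_checkpointing_enable(self) -> None:
        if not self.supports_gradient_checkpointing:
            raise ValueError(
                f"{self.__class__.__name__} does not support gradient checkpointing."
            )
        self._set_gradient_checkpointing(True)

    def gradient_checkpointing_disable(self) -> None:
        if self.supports_gradient_checkpointing:
            self._set_gradient_checkpointing(False)
```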
I took the liberty of also documenting this feature in https://huggingface.co/transformers/performance.html and pushed it here, so if you rename the method please adjust the doc as well. Thank you!
I'm not very happy about keeping [...]. I'm very much in favor of removing [...].
I am not following since this is all private. The user does not have to know anything about model configurations for this option. I'm also not sure which new exceptions you are mentioning?
Note that those submodules are often not even [...]. In any case, if this second approach is selected, I would still urge merging this PR as soon as possible to avoid any merge conflicts or many users diverging from the templates. We can then change the internal implementation of the models more progressively.
I'm just a bit worried that we'll start using the "private" configuration parameters of [...] more and more. For a user that just looks at the configuration on the hub this PR is great, but for users that actually look into the code, adding a [...]. Think we should be able to add a single method to the [...]:

```python
def _enable_gradient_checkpointing(self):
    model = self
    if hasattr(model, self.base_model_prefix):
        model = getattr(model, self.base_model_prefix)
    # set gradient checkpointing to True in the encoder
    model.encoder.gradient_checkpointing = True
```

=> this should work just fine, no?

Given that we will have to leave it in the config anyways until v5, I'm fine with leveraging the config I guess - I just don't think it's good practice to introduce "special" configuration parameters with [...].
If we leave the config as is, as proposed by Patrick, should we perhaps discuss the ability for the user to choose what goes into the published model's config? We are sort of trying to do DWIM (do what I mean) and magically have the published model end up with all the right settings. So we could add to the model saving interface our default filters, which would for example automatically disable gradient_checkpointing. In the current PR the user has no control over [...]. And we won't need to wait till v5 to do so.
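One way to picture that proposal (purely illustrative; neither these names nor this hook exist in the library as written):

```python
from typing import Iterable, Optional

# Hypothetical save-time filter: drop training-only settings before a config is published.
DEFAULT_TRAINING_ONLY_KEYS = {"gradient_checkpointing"}


def filter_config_for_hub(config_dict: dict, user_keys_to_drop: Optional[Iterable[str]] = None) -> dict:
    """Return a copy of the config dict without training-only keys.

    `user_keys_to_drop` stands for the user-controlled part of the proposal.
    """
    to_drop = DEFAULT_TRAINING_ONLY_KEYS | set(user_keys_to_drop or [])
    return {key: value for key, value in config_dict.items() if key not in to_drop}
```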
@stas00 This is out of scope of this PR (which does not contain the [...]).
I was just following up on Patrick's comment. I have no problem with not discussing it here.
Thanks a lot for the extra effort! Really like the new design
Looks good to me! Thank you for iterating.
* Make gradient_checkpointing a training argument
* Update src/transformers/modeling_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
* Update src/transformers/configuration_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
* Fix tests
* Style
* document Gradient Checkpointing as a performance feature
* Small rename
* PoC for not using the config
* Adapt BC to new PoC
* Forgot to save
* Rollout changes to all other models
* Fix typo

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
It uses `flax.linen.remat` and follows on PRs huggingface#13657 and huggingface#17994
What does this PR do?
This PR reworks the logic behind gradient checkpointing. It is currently set as a configuration argument, which is annoying for several reasons.

That's why this PR deprecates the `gradient_checkpointing` argument in every config and adds:

* a `gradient_checkpointing_enable` method on `PreTrainedModel` to activate gradient checkpointing
* a `gradient_checkpointing` argument to the `Trainer` API that will call that method

Internally, the implementation still relies on the config, as it's the easiest place to set something that needs to pass through several layers of a model (if we have a `BertForMaskedLM` for instance, the actual gradient checkpointing only applies to the `BertEncoder` inside the `BertModel` inside that `BertForMaskedLM`), but that argument is made private and not saved to the model Hub.
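Putting the two additions together, a training setup after this PR would look roughly like this (checkpoint name and output directory are placeholders):

```python
from transformers import AutoModelForMaskedLM, Trainer, TrainingArguments

model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

# Option 1: activate gradient checkpointing directly on the model.
model.gradient_checkpointing_enable()

# Option 2: let the Trainer activate it through the new training argument,
# which calls model.gradient_checkpointing_enable() under the hood.
args = TrainingArguments(output_dir="out", gradient_checkpointing=True)
trainer = Trainer(model=model, args=args)
```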