[core] Refactor of gradient_checkpointing #27020
```diff
@@ -951,11 +951,6 @@ class TFSwinPreTrainedModel(TFPreTrainedModel):
     config_class = SwinConfig
     base_model_prefix = "swin"
     main_input_name = "pixel_values"
-    supports_gradient_checkpointing = True
```
Here I removed it because it's not relevant to TF models.
Very nice cleanup!
```diff
@@ -1845,7 +1858,7 @@ def gradient_checkpointing_disable(self):
         activations".
         """
         if self.supports_gradient_checkpointing:
-            self.apply(partial(self._set_gradient_checkpointing, value=False))
+            self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=None))
```
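To make the new calling convention concrete, here is a minimal sketch of how the enable/disable pair fits together after the refactor. It is reconstructed from the hunks shown in this PR, not copied from `modeling_utils.py`, and the merged code may differ in detail.

```python
# Hedged sketch reconstructed from the diff hunks in this PR.
from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyEncoderLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        # Flag + callable are set by the parent model via _set_gradient_checkpointing.
        self.gradient_checkpointing = False
        self.gradient_checkpointing_func = None

    def forward(self, x):
        if self.gradient_checkpointing and self.training:
            # Checkpoint through __call__ so forward hooks still run.
            return self.gradient_checkpointing_func(self.linear.__call__, x)
        return self.linear(x)


class TinyModel(nn.Module):
    supports_gradient_checkpointing = True

    def __init__(self):
        super().__init__()
        self.layer = TinyEncoderLayer()

    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
        # Submodules store the checkpointing callable itself; None means "disabled".
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing_func = gradient_checkpointing_func
            module.gradient_checkpointing = gradient_checkpointing_func is not None

    def gradient_checkpointing_enable(self):
        if self.supports_gradient_checkpointing:
            self.apply(partial(self._set_gradient_checkpointing,
                               gradient_checkpointing_func=torch.utils.checkpoint.checkpoint))

    def gradient_checkpointing_disable(self):
        if self.supports_gradient_checkpointing:
            self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=None))

    def forward(self, x):
        return self.layer(x)
```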
When we disable gradient checkpointing, I think `module.gradient_checkpointing` will still be `True`. Let's make `module.gradient_checkpointing` a property so we always check whether the function is `None` or not, WDYT?
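A minimal sketch of the suggested property (hypothetical — not necessarily what was ultimately merged):

```python
import torch.nn as nn


class CheckpointableLayer(nn.Module):
    """Hypothetical layer: the boolean flag is derived from the callable."""

    def __init__(self):
        super().__init__()
        self.gradient_checkpointing_func = None

    @property
    def gradient_checkpointing(self):
        # Disabling (setting the function back to None) can never leave a stale
        # True flag behind, since the flag is recomputed on every access.
        return self.gradient_checkpointing_func is not None
```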
The property could go at the `ModelMixin` level?
Can you add a test to make sure setting and unsetting both work as expected (specifically for the fix we are implementing in TRL)?
+1
```python
# Enable / disable GC for the language model as well
if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"):
    self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func)
```
BLIP2 never propagated `gradient_checkpointing` to its `language_model`.
```python
# Enable / disable GC for the language model as well
if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"):
    self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func)
```
Same here
```python
for backbone_module in module.modules():
    if hasattr(backbone_module, "gradient_checkpointing"):
        backbone_module.gradient_checkpointing_func = gradient_checkpointing_func
        backbone_module.gradient_checkpointing = gradient_checkpointing_func is not None
```
Another edge case here: the backbone has some modules that support GC, but that attribute was never propagated to them.
It turns out ~30 architectures were not properly using `gradient_checkpointing`; I left 3 comments to be aware of.
I think we should use `__call__` rather than `forward` so that the hooks run!
Thanks a lot, very nice cleanup! 🔥
```diff
-            layer_outputs = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(layer_module),
+            layer_outputs = self.gradient_checkpointing_func(
+                layer_module.__call__,
```
Let's document this in the gradient checkpointing doc (IMO important to know why `forward` and `__call__` are different).
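For illustration (a toy example, not transformers code): `module.__call__` routes through `nn.Module._call_impl`, which runs registered forward pre-/post-hooks, while calling `module.forward` directly bypasses them — and such hooks are how libraries like accelerate handle things such as device placement.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


toy = Toy()
toy.register_forward_hook(lambda mod, inputs, output: print("forward hook ran"))
x = torch.randn(2, 4, requires_grad=True)

# Hooks fire: __call__ goes through nn.Module._call_impl.
torch.utils.checkpoint.checkpoint(toy.__call__, x, use_reentrant=False)

# Hooks are silently skipped: forward is invoked as a plain method.
torch.utils.checkpoint.checkpoint(toy.forward, x, use_reentrant=False)
```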
Ran some training tests with PEFT + GC using this branch and everything seems to pass! Merging once the CI is green.
* v1
* fix
* remove `create_custom_forward`
* fixup
* fixup
* add test and fix all failing GC tests
* remove all remaining `create_custom_forward` methods
* fix idefics bug
* fixup
* replace with `__call__`
* add comment
* quality
## Describe your changes

The latest version of transformers (>= 4.35.0) is not compatible with the model. PRs huggingface/transformers#27020 and huggingface/transformers#27073 change the expected signature of `_set_gradient_checkpointing`, which now doesn't match the model's: https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_mixformer_sequential.py#L802

## Checklist before requesting a review

- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Format your code by running `pre-commit run --all-files`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.

## (Optional) Issue link
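For reference, a rough sketch of the signature change that breaks such custom-code models. The "before" form matches the `value=False` call removed in this PR's diff; the "after" form is what the refactor passes via `partial(..., gradient_checkpointing_func=...)`. The function names below are illustrative only.

```python
# Before (transformers < 4.35), custom models typically implemented roughly:
def _set_gradient_checkpointing_old(self, module, value=False):
    if hasattr(module, "gradient_checkpointing"):
        module.gradient_checkpointing = value


# After the refactor, the method receives the checkpointing callable instead of
# a boolean, so the old signature no longer matches what the base model passes.
def _set_gradient_checkpointing_new(self, module, gradient_checkpointing_func=None):
    if hasattr(module, "gradient_checkpointing"):
        module.gradient_checkpointing_func = gradient_checkpointing_func
        module.gradient_checkpointing = gradient_checkpointing_func is not None
```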
What is the difference between `enable_gradient_checkpointing` and `gradient_checkpointing_enable`?
@lucasjinreal I can only see
🤯 transformers/src/transformers/modeling_utils.py, line 2195 in af4c026
What does this PR do?

Alternative to #26917.

This way we make `set_gradient_checkpointing` more modular, as requested by some users - e.g. #21381 (comment). Fixes some issues with DDP such as huggingface/trl#835.

Also removed GC support from `TFSwin`, as in theory `gradient_checkpointing` is used only for PT models. Also added a CI test for that.

For users that want to use `gradient_checkpointing` with `use_reentrant=False`:
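The code snippet that originally followed was not captured in this scrape. As a hedged sketch, in transformers >= 4.35 the user-facing call forwards keyword arguments to `torch.utils.checkpoint.checkpoint` via `gradient_checkpointing_kwargs` (the model name below is just an example):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # any PT model that supports GC

# Kwargs are forwarded to torch.utils.checkpoint.checkpoint under the hood,
# enabling non-reentrant activation checkpointing.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
```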