[core] Refactor of gradient_checkpointing #27020

Merged 16 commits on Oct 25, 2023

Conversation

@younesbelkada (Contributor) commented Oct 23, 2023

What does this PR do?

Alternative to #26917

This way we make set_gradient_checkpointing more modular, as requested by some users, e.g. #21381 (comment).

Fixes some issues with DDP, such as huggingface/trl#835.

Also removed GC support from TFSwin, since in theory gradient_checkpointing is only used for PT models.

Also added a CI test for that.

For users who want to use gradient_checkpointing with use_reentrant=False:

...
model.enable_gradient_checkpointing(gradient_checkpointing_kwargs={"use_reentrant": False})
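
A minimal end-to-end sketch of that usage, based on the gradient_checkpointing_enable signature shown further down in this thread (the checkpoint name is only illustrative):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
# non-reentrant checkpointing avoids the DDP issues mentioned above
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})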

@HuggingFaceDocBuilderDev commented Oct 23, 2023

The documentation is not available anymore as the PR was closed or merged.

@@ -951,11 +951,6 @@ class TFSwinPreTrainedModel(TFPreTrainedModel):
config_class = SwinConfig
base_model_prefix = "swin"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
@younesbelkada (Contributor, PR author):

Removed here because it is not relevant to TF models.

@ArthurZucker (Collaborator) left a comment:

Very nice cleanup!

@@ -1845,7 +1858,7 @@ def gradient_checkpointing_disable(self):
activations".
"""
if self.supports_gradient_checkpointing:
self.apply(partial(self._set_gradient_checkpointing, value=False))
self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=None))
Collaborator:

When we disable gradient checkpointing, I think module.gradient_checkpointing will still be True.
Let's make module.gradient_checkpointing a property so that we always check whether the function is None or not. WDYT?
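
A rough sketch of what that property could look like (illustrative only, not the code from this PR):

import torch.nn as nn

class MyLayer(nn.Module):  # illustrative layer, not an actual transformers module
    gradient_checkpointing_func = None

    @property
    def gradient_checkpointing(self):
        # treat checkpointing as enabled only while a checkpointing function is set
        return self.gradient_checkpointing_func is not None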

Collaborator:

The property could go at the ModelMixin level?

@ArthurZucker (Collaborator):

Can you add a test to make sure that setting and unsetting both work as expected (specifically for the fix we are implementing in TRL)?
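
Something along these lines, sketched from the attributes this PR sets on submodules (the helper name and structure are illustrative, not the final test):

def test_gradient_checkpointing_toggle(model):
    # enabling should set the flag on at least one submodule that supports GC
    model.gradient_checkpointing_enable()
    assert any(getattr(m, "gradient_checkpointing", False) for m in model.modules())

    # disabling should clear the flag everywhere
    model.gradient_checkpointing_disable()
    assert not any(getattr(m, "gradient_checkpointing", False) for m in model.modules())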

@LysandreJik (Member):

+1

Comment on lines +305 to +307
# Enable / disable GC for the language model as well
if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"):
self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func)
@younesbelkada (Contributor, PR author):

BLIP2 never propagated gradient_checkpointing to its language_model

Comment on lines +312 to +314
# Enable / disable GC for the language model as well
if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"):
self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func)
@younesbelkada (Contributor, PR author):

Same here

Comment on lines +94 to +97
for backbone_module in module.modules():
if hasattr(backbone_module, "gradient_checkpointing"):
backbone_module.gradient_checkpointing_func = gradient_checkpointing_func
backbone_module.gradient_checkpointing = gradient_checkpointing_func is not None
@younesbelkada (Contributor, PR author):

Another edge case here: the backbone has some modules that support GC, but that attribute was never propagated.

@younesbelkada (Contributor, PR author) left a comment:

It turns out ~30 architectures were not properly using gradient_checkpointing; I left 3 comments to be aware of.

@ArthurZucker (Collaborator) left a comment:

I think we should use __call__ rather than forward so that the hooks run!

@ArthurZucker mentioned this pull request Oct 25, 2023
@ArthurZucker (Collaborator) left a comment:

Thanks a lot, very nice cleanup! 🔥

layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
Collaborator:

Let's document this in the gradient checkpointing doc (IMO important to know why forward and __call__ are different).
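
For reference, a small illustration of the difference (nn.Module.__call__ runs any registered forward/backward hooks around forward(), while calling forward() directly bypasses them):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
layer.register_forward_hook(lambda module, inputs, output: print("hook ran"))

x = torch.randn(2, 4)
layer(x)          # goes through __call__, so the hook fires and prints "hook ran"
layer.forward(x)  # skips __call__, so the hook never runs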

@younesbelkada (Contributor, PR author):

Ran some training tests with PEFT + GC using this branch and everything seems to pass! Merging once the CI is green.

@younesbelkada merged commit 06e782d into huggingface:main Oct 25, 2023
22 checks passed
@younesbelkada deleted the final-fix-gc branch October 25, 2023 10:16
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

i4never pushed a commit to i4never/transformers that referenced this pull request Oct 26, 2023
* v1

* fix

* remove `create_custom_forward`

* fixup

* fixup

* add test and fix all failing GC tests

* remove all remaining `create_custom_forward` methods

* fix idefics bug

* fixup

* replace with `__call__`

* add comment

* quality
jambayk added a commit to microsoft/Olive that referenced this pull request Nov 2, 2023
## Describe your changes
The latest version of transformers (>= 4.35.0) is not compatible with the model. PRs huggingface/transformers#27020 and huggingface/transformers#27073 change the expected signature of `_set_gradient_checkpointing`, which no longer matches the model's implementation at
https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_mixformer_sequential.py#L802
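
For context, a sketch of the signature change (the new argument names are taken from the transformers PR diffs; the bodies shown here are illustrative, not the phi-1_5 code):

# transformers < 4.35.0 expected something like:
def _set_gradient_checkpointing(self, module, value=False):
    if hasattr(module, "gradient_checkpointing"):
        module.gradient_checkpointing = value

# after the refactor, a checkpointing function is passed instead:
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
    if hasattr(module, "gradient_checkpointing"):
        module.gradient_checkpointing_func = gradient_checkpointing_func
        module.gradient_checkpointing = gradient_checkpointing_func is not None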

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Format your code by running `pre-commit run --all-files`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.

## (Optional) Issue link
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
(same commit message as above)
@lucasjinreal:
What is the difference between enable_gradient_checkpointing and gradient_checkpointing_enable?

@younesbelkada (Contributor, PR author):

@lucasjinreal commented Apr 9, 2024:

I am using model.enable_gradient_checkpointing and no errors appear...

[screenshot omitted]

How do you explain this?

@younesbelkada (Contributor, PR author):

🤯
I think this was a typo; it is weird that you don't get any error. You should use gradient_checkpointing_enable:

def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
