Gradient checkpointing not applied to UNet mid_block #4377

Closed
laksjdjf opened this issue Jul 31, 2023 · 2 comments

@laksjdjf (Contributor)

The mid block of the SDXL UNet is huge, so applying gradient checkpointing to it significantly reduces VRAM usage.

I have tested the following changes and have seen great results.
main...laksjdjf:diffusers:mid_block_gradient_checkpointing

However, there seem to be several variants of the mid block, including UNetMidBlock2DSimpleCrossAttn and UNetMidBlock3DCrossAttn, and I am not sure how to handle them.

By the way, Kohya's trainer applies gradient checkpointing to all blocks.
https://github.com/kohya-ss/sd-scripts/blob/4072f723c12822e2fa1b2e076cc1f90b8f4e30c9/library/sdxl_original_unet.py#L1035-L1041
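For illustration, here is a minimal sketch of the pattern being discussed: routing a UNet's mid block through `torch.utils.checkpoint` so its activations are recomputed during backward instead of being stored. The `ToyMidBlock`/`ToyUNet` classes are hypothetical stand-ins, not the actual diffusers implementation.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyMidBlock(nn.Module):
    """Hypothetical stand-in for a UNet mid block (e.g. UNetMidBlock2DCrossAttn)."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.conv2(torch.relu(self.conv1(hidden_states)))


class ToyUNet(nn.Module):
    def __init__(self, channels: int = 8) -> None:
        super().__init__()
        self.mid_block = ToyMidBlock(channels)
        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self) -> None:
        self.gradient_checkpointing = True

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        if self.gradient_checkpointing and self.training:
            # Recompute the mid block's activations in the backward pass
            # instead of keeping them in memory, trading compute for VRAM.
            sample = checkpoint(self.mid_block, sample, use_reentrant=False)
        else:
            sample = self.mid_block(sample)
        return sample


unet = ToyUNet()
unet.enable_gradient_checkpointing()
unet.train()
x = torch.randn(1, 8, 4, 4, requires_grad=True)
unet(x).sum().backward()  # gradients flow through the checkpointed mid block
```

The gradients are identical to the non-checkpointed path; only the peak activation memory changes, which is why the savings grow with the size of the block being checkpointed.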

@sayakpaul (Member)

Cc: @patrickvonplaten

@patrickvonplaten (Contributor)

Great catch - yes for SDXL we should indeed apply gradient checkpointing to the midblock as well :-)

Would you like to open a PR for it? This would be a great addition for the community I believe :-)
