[SDXL] Allow SDXL LoRA to be run with less than 16GB of VRAM #4470
Conversation
The documentation is not available anymore as the PR was closed or merged.
```diff
@@ -899,6 +899,7 @@ def load_model_hook(models, input_dir):

     if args.gradient_checkpointing:
         controlnet.enable_gradient_checkpointing()
+        unet.enable_gradient_checkpointing()
```
some gradients also flow through the unet so it's always safer IMO to enable it as well
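For context, this is roughly what enabling checkpointing on both models looks like through the diffusers API (a minimal sketch, not the training script; the model ID and `from_unet` initialization are just illustrative):

```python
from diffusers import ControlNetModel, UNet2DConditionModel

# Load the SDXL UNet and derive a ControlNet from it, then enable gradient
# checkpointing on both. Only the ControlNet is trained, but the loss is computed
# from the UNet output, so the backward pass traverses the UNet's activations too;
# checkpointing the UNet therefore also reduces peak memory.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
controlnet = ControlNetModel.from_unet(unet)

controlnet.enable_gradient_checkpointing()
unet.enable_gradient_checkpointing()
```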
```diff
@@ -839,6 +839,11 @@ def main(args):
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")

+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
```
For LoRA, gradients still flow through the unet, so we should enable gradient checkpointing.
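A toy illustration of that point (assumed names, nothing from the script): the big module is frozen, but the gradient of a small trainable adapter placed upstream still has to flow back through the frozen module's activations, which is exactly what checkpointing recomputes instead of storing.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

adapter = nn.Linear(512, 512)                                     # trainable, LoRA-like
frozen = nn.Sequential(*[nn.Linear(512, 512) for _ in range(8)])  # frozen, unet-like
for p in frozen.parameters():
    p.requires_grad_(False)

x = torch.randn(4, 512)
h = adapter(x)
# The frozen stack's activations are recomputed during backward instead of stored.
out = checkpoint(frozen, h, use_reentrant=False)
out.pow(2).mean().backward()
print(adapter.weight.grad is not None)  # True: the gradient flowed through the frozen stack
```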
```python
            class_labels=class_labels,
        )
        if self.training and self.gradient_checkpointing:
            hidden_states = torch.utils.checkpoint.checkpoint(
```
SDXL has 10 Transformer blocks in its inner-resolution blocks, so checkpointing those can help quite a bit.
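A quick way to verify that claim from the released checkpoint (this downloads the SDXL UNet; the expected value is an assumption based on the published config):

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
# The innermost attention blocks stack 10 transformer layers each, which is why
# checkpointing them saves a disproportionate amount of activation memory.
print(unet.config.transformer_layers_per_block)  # expected: [1, 2, 10]
```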
Makes sense.
From the documentation:

> Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward. This can cause persistent states like the RNG state to be advanced further than they would be without checkpointing. By default, checkpointing includes logic to juggle the RNG state such that checkpointed passes making use of RNG (through dropout, for example) have deterministic output as compared to non-checkpointed passes. The logic to stash and restore RNG states can incur a moderate performance hit depending on the runtime of checkpointed operations. If deterministic output compared to non-checkpointed passes is not required, supply `preserve_rng_state=False` to `checkpoint` or `checkpoint_sequential` to omit stashing and restoring the RNG state during each checkpoint.

Do we need the deterministic-output comparison with non-checkpointed passes here? If not, we could also disable `preserve_rng_state`, no?
I think that's ok; that's only relevant for customized gradient backprop workflows IMO.
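For reference, this is what passing the flag looks like on a toy module that actually uses RNG in its forward (a sketch, not the diffusers code):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Dropout is what makes the stashed RNG state matter: by default checkpoint() saves
# and restores it so the recomputed dropout mask matches the original forward.
# preserve_rng_state=False skips that bookkeeping, at the cost that the mask used
# during recomputation may differ from the one used in the original forward.
layer = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Dropout(p=0.1))
x = torch.randn(2, 64, requires_grad=True)

out = checkpoint(layer, x, use_reentrant=False, preserve_rng_state=False)
out.sum().backward()
print(x.grad.shape)  # torch.Size([2, 64])
```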
```python
                return_dict=False,
            )[0]
            hidden_states = resnet(hidden_states, temb)
        if self.training and self.gradient_checkpointing:
```
Let's also checkpoint the middle block
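A stripped-down stand-in for what checkpointing such a block means (illustrative module names, not the diffusers implementation): a first resnet-like layer followed by alternating attention/resnet pairs, each recomputed during backward when checkpointing is on.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyMidBlock(nn.Module):
    def __init__(self, dim=128, num_layers=2):
        super().__init__()
        self.gradient_checkpointing = True
        self.resnets = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers + 1)])
        self.attentions = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])  # attention stand-ins

    def forward(self, hidden_states):
        hidden_states = self.resnets[0](hidden_states)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if self.training and self.gradient_checkpointing:
                # Recompute each sub-module in backward instead of storing its activations.
                hidden_states = checkpoint(attn, hidden_states, use_reentrant=False)
                hidden_states = checkpoint(resnet, hidden_states, use_reentrant=False)
            else:
                hidden_states = attn(hidden_states)
                hidden_states = resnet(hidden_states)
        return hidden_states

block = ToyMidBlock()
block(torch.randn(2, 16, 128, requires_grad=True)).mean().backward()
```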
Failing test seems to be due to a flaky internet connection.
```diff
+  --mixed_precision="fp16" \
```

and making sure that you have the following libraries installed:
Do we need to put a disclaimer that training with text encoder is still not possible within a free-tier Colab?
```python
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()
```
`enable_gradient_checkpointing()` and `gradient_checkpointing_enable()` will raise eyebrows, but no biggie.
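For readers hitting the same double-take: the two spellings come from two libraries. A short sketch (the model ID is just an example checkpoint):

```python
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
text_encoder = CLIPTextModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder"
)

unet.enable_gradient_checkpointing()          # diffusers naming
text_encoder.gradient_checkpointing_enable()  # transformers naming
```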
```diff
-    if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)):
+    if hasattr(module, "gradient_checkpointing"):
         module.gradient_checkpointing = value
```
Clean.
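For readers wondering how that toggle propagates, here is a stripped-down version of the pattern (a sketch of the mechanism, not the diffusers source):

```python
import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([TinyBlock(), TinyBlock()])

    def _set_gradient_checkpointing(self, module, value=False):
        # Any submodule that carries a `gradient_checkpointing` attribute gets flipped.
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def enable_gradient_checkpointing(self):
        # apply() visits every submodule, so new block types opt in automatically.
        self.apply(lambda m: self._set_gradient_checkpointing(m, True))

model = TinyModel()
model.enable_gradient_checkpointing()
print([b.gradient_checkpointing for b in model.blocks])  # [True, True]
```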
Merge away!
Could you maybe do an empty commit to ensure the CI is 🍏?
WDYM? If …
```python
        text_encoder_one.gradient_checkpointing_enable()
        text_encoder_two.gradient_checkpointing_enable()
```
The pix2pix script uses `text_encoder_1` and `_2`, which match the actual internal names. Can we make them consistent?
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
cool!
I successfully ran this code on the free version of Colab with a T4 GPU. But a processing time of ~3 hours for 500 steps is confusing. I have trained a lot of regular LoRAs, but never DreamBooth models, especially under these specific conditions, so sorry for the noob questions. I took the exact settings from the dog example, but how free am I to change them without running out of RAM/VRAM? Also, are the default/example settings universal? For example, I have 77 pictures of my own dog; should I add more or reduce the amount of data? And, as I assume, `--max_train_steps` is the number of steps after which the model will be saved? Usually my training ends with 10 epochs, which gives me the opportunity to test and choose the best one.
I give up for now. After ~10 attempts, using another VAE as suggested in the dog example, I managed to get a total time of ~1h30m for 500 steps on Google Colab with a T4. But every time, after all 500 steps, I run out of VRAM or RAM; sometimes by 2 MB, sometimes by 20. I duplicate my arguments below. If somebody has any suggestions, please let me know!
Yes, it is, thank you! I don't know how you managed to do that. In my case, I just used the code from `train_dreambooth_lora_sdxl.py` with the suggested args and `from accelerate.utils import write_basic_config`.
I got the same error.
Try using a `checkpointing_steps` value larger than your total maximum training steps.
Thank you for your quick response. When I use a `checkpointing_steps` value larger than the total maximum training steps, training and saving results work fine. I want to resume training from the saved checkpoints, but I don't know how to.
|
Thank you very much. |
Checkpointing likely consumes a lot of VRAM because image validations end up loading the text encoder again.
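One way to keep those validation runs from permanently inflating memory is to build the validation pipeline, generate, and immediately drop it again. A sketch under assumed names (in the training script the fine-tuned UNet/LoRA weights would be plugged into the pipeline; they are omitted here for brevity):

```python
import gc
import torch
from diffusers import StableDiffusionXLPipeline

def run_validation(prompt, device="cuda"):
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to(device)
    images = pipeline(prompt, num_inference_steps=25).images
    # Drop the pipeline (UNet, VAE and both text encoders) and release cached blocks
    # so the next training step gets the memory back.
    del pipeline
    gc.collect()
    torch.cuda.empty_cache()
    return images
```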
@sayakpaul sorry for bothering you, but is there any chance to convert the resulting .bin to .safetensors? Or to save it in safetensors format straight away? The Kohya_ss script is able to train directly to safetensors, but as I presume there is no chance to run it on a T4...
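Converting an existing `.bin` is straightforward, since it is just a `torch.save`d dict of tensors. A minimal sketch (file names are placeholders for whatever your run produced):

```python
import torch
from safetensors.torch import save_file

state_dict = torch.load("pytorch_lora_weights.bin", map_location="cpu")
save_file(state_dict, "pytorch_lora_weights.safetensors")
```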
…face#4470) * correct * correct blocks * finish * finish * finish * Apply suggestions from code review * fix * up * up * up * Update examples/dreambooth/README_sdxl.md Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * Apply suggestions from code review --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
As the linked issues point out, training LoRA currently requires a significant amount of memory.
It's quite tough to get SDXL to use less memory: the unet weights alone are around 6GB in fp16 precision, and since we're training at a resolution of 1024, which translates into a latent resolution of 128, things are quite bottlenecked.
In this PR, I optimized gradient checkpointing further by adding it to the Transformer2D model and the cross-attention mid block.
Besides gradient checkpointing, using `xformers` as well as `bitsandbytes` for Adam helped quite a bit, so the final memory requirement for SDXL non-text-encoder LoRA training can be brought down to something like 14.8 GB, which fits on a T4 and thus on a free Google Colab.
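Putting those pieces together at the API level looks roughly like this (a sketch under the same assumptions as the PR, not the actual training script, which only passes the LoRA adapter parameters to the optimizer):

```python
import bitsandbytes as bnb
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet.enable_gradient_checkpointing()
unet.enable_xformers_memory_efficient_attention()  # requires xformers to be installed

# In the real script only the LoRA parameters are optimized; unet.parameters() is used
# here just to keep the sketch self-contained.
optimizer = bnb.optim.AdamW8bit(unet.parameters(), lr=1e-4)
```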