[SDXL] Allow SDXL LoRA to be run with less than 16GB of VRAM #4470

Merged
patrickvonplaten merged 13 commits into main from improve_backprop on Aug 4, 2023

Conversation

patrickvonplaten
Contributor

@patrickvonplaten commented Aug 4, 2023

As the linked issues point out, training LoRA currently requires a significant amount of memory.

It's quite tough to get SDXL to use less memory: the UNet weights alone are around 6 GB in fp16 precision, and we train at a resolution of 1024, which translates into a latent resolution of 128 (the VAE downsamples by a factor of 8), so things are quite bottlenecked.

In this PR I used the following code snippet:

from diffusers import UNet2DConditionModel
import torch

# Cap this process at ~40% of the GPU's memory to emulate a smaller card.
torch.cuda.set_per_process_memory_fraction(0.4, device="cuda:1")

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16", torch_dtype=torch.float16
)
unet.train()
unet.enable_gradient_checkpointing()
unet = unet.to("cuda:1")

batch_size = 2

# Dummy SDXL inputs: 128x128 latents, micro-conditioning time ids,
# per-token text embeddings, and pooled text embeddings.
sample = torch.randn((1, 4, 128, 128)).half().to(unet.device).repeat(batch_size, 1, 1, 1)
time_ids = (torch.arange(6) / 6)[None, :].half().to(unet.device).repeat(batch_size, 1)
encoder_hidden_states = torch.randn((1, 77, 2048)).half().to(unet.device).repeat(batch_size, 1, 1)
text_embeds = torch.randn((1, 1280)).half().to(unet.device).repeat(batch_size, 1)

# One forward/backward pass through the UNet with a dummy MSE loss.
out = unet(sample, 1.0, added_cond_kwargs={"time_ids": time_ids, "text_embeds": text_embeds}, encoder_hidden_states=encoder_hidden_states).sample

loss = ((out - sample) ** 2).mean()
loss.backward()

# Peak memory allocated during the forward + backward pass.
print(torch.cuda.max_memory_allocated(device=unet.device))

to measure peak memory while optimizing gradient checkpointing further, by adding it to the Transformer2DModel and the cross-attention mid block.

Besides gradient checkpointing, using xformers memory-efficient attention as well as bitsandbytes for 8-bit Adam helped quite a bit, so that the final memory requirement for SDXL LoRA training without the text encoders can be brought down to roughly 14.8 GB, which fits on a T4 and thus on a free Google Colab.
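A minimal sketch of wiring these two savers together on top of gradient checkpointing (illustrative only, not the PR's training script; assumes xformers and bitsandbytes are installed):

```python
import torch
import bitsandbytes as bnb
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
    variant="fp16",
    torch_dtype=torch.float16,
).to("cuda")
unet.train()

# Recompute activations during backward instead of storing them.
unet.enable_gradient_checkpointing()
# Memory-efficient attention kernels from xformers.
unet.enable_xformers_memory_efficient_attention()

# 8-bit Adam from bitsandbytes keeps optimizer states in 8 bits instead of fp32.
# In the LoRA script only the adapter parameters would be passed here.
optimizer = bnb.optim.AdamW8bit(unet.parameters(), lr=1e-4)
```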

[Screenshot from 2023-08-04 16-49-40]

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Aug 4, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten changed the title from "correct" to "Improve Gradient Checkpointing and memory optimization for SDXL" on Aug 4, 2023
@@ -899,6 +899,7 @@ def load_model_hook(models, input_dir):

if args.gradient_checkpointing:
controlnet.enable_gradient_checkpointing()
unet.enable_gradient_checkpointing()
Contributor Author

Some gradients also flow through the UNet, so IMO it's always safer to enable it there as well.

@@ -839,6 +839,11 @@ def main(args):
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
Contributor Author

For LoRA, gradients still flow through the UNet, so we should enable gradient checkpointing there as well.

class_labels=class_labels,
)
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
Contributor Author

SDXL has 10 Transformer blocks in the inner-resolution blocks, so checkpointing those can help quite a bit.
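Roughly the pattern being applied here, as a simplified standalone sketch (the real change lives inside diffusers' Transformer2DModel and the UNet mid block):

```python
import torch
import torch.utils.checkpoint

# Simplified sketch of the checkpointing pattern from the diff above: when
# training with gradient checkpointing on, each block's activations are
# recomputed during the backward pass instead of being kept in memory.
class TinyTransformerStack(torch.nn.Module):
    def __init__(self, dim: int = 64, num_blocks: int = 10):
        super().__init__()
        self.blocks = torch.nn.ModuleList(
            [torch.nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True) for _ in range(num_blocks)]
        )
        self.gradient_checkpointing = False

    def forward(self, hidden_states):
        for block in self.blocks:
            if self.training and self.gradient_checkpointing:
                # Non-reentrant checkpointing; recomputes `block` in backward.
                hidden_states = torch.utils.checkpoint.checkpoint(block, hidden_states, use_reentrant=False)
            else:
                hidden_states = block(hidden_states)
        return hidden_states
```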

Member

Makes sense.

From the documentation:

Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward. This can cause persistent states like the RNG state to be advanced more than they would without checkpointing. By default, checkpointing includes logic to juggle the RNG state such that checkpointed passes making use of RNG (through dropout for example) have deterministic output as compared to non-checkpointed passes. The logic to stash and restore RNG states can incur a moderate performance hit depending on the runtime of checkpointed operations. If deterministic output compared to non-checkpointed passes is not required, supply preserve_rng_state=False to checkpoint or checkpoint_sequential to omit stashing and restoring the RNG state during each checkpoint.

Do we need the deterministic-output guarantee compared to non-checkpointed passes? If not, we could also disable preserve_rng_state, no?

Contributor Author

@patrickvonplaten Aug 4, 2023

I think that's OK; that's only relevant for customized gradient backprop workflows IMO.
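For reference, the option discussed above is just a keyword argument on torch.utils.checkpoint.checkpoint; a minimal illustration (not a change made in this PR):

```python
import torch
import torch.utils.checkpoint

block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Dropout(p=0.1))
x = torch.randn(4, 16, requires_grad=True)

# Default: the RNG state is stashed/restored so dropout inside the checkpointed
# segment matches a non-checkpointed forward pass exactly.
y_default = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)

# If that parity is not needed, skipping the RNG bookkeeping avoids a small
# overhead per checkpointed segment.
y_fast = torch.utils.checkpoint.checkpoint(
    block, x, use_reentrant=False, preserve_rng_state=False
)
```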

return_dict=False,
)[0]
hidden_states = resnet(hidden_states, temb)
if self.training and self.gradient_checkpointing:
Contributor Author

Let's also checkpoint the middle block

@patrickvonplaten changed the title from "Improve Gradient Checkpointing and memory optimization for SDXL" to "[SDXL] Allow SDXL LoRA to be run with less than 16GB of VRAM" on Aug 4, 2023
@patrickvonplaten
Contributor Author

The failing test seems to be due to a flaky internet connection.

+ --mixed_precision="fp16" \
```

and making sure that you have the following libraries installed:
Member

Do we need to put a disclaimer that training with the text encoder is still not possible within a free-tier Colab?

if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.train_text_encoder:
text_encoder.gradient_checkpointing_enable()
Member

The mismatch between enable_gradient_checkpointing() (diffusers) and gradient_checkpointing_enable() (transformers) will raise eyebrows, but no biggie.

Comment on lines -803 to 804
if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
Member

Clean.

Member

@sayakpaul left a comment

Merge away!

Could you maybe do an empty commit to ensure the CI is 🍏?

@bghira

This comment was marked as duplicate.

@sayakpaul
Member

> are you also not interested in fully unloading the text encoders since they aren't actually needed for the training?

WDYM?

If train_text_encoder is not specified, then we always precompute the embeddings and free them up.

https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py#L930C1-L934C33
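Roughly the idea being referenced, as a sketch with hypothetical helper and variable names (the linked script contains the real implementation):

```python
import gc
import torch

# Sketch only: `encode_prompt`, the text encoders, and the tokenizers are
# hypothetical stand-ins for the objects the linked script actually defines.
with torch.no_grad():
    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        text_encoders=[text_encoder_one, text_encoder_two],
        tokenizers=[tokenizer_one, tokenizer_two],
        prompt=instance_prompt,
    )

# The encoders are no longer needed, so free their VRAM before training.
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
gc.collect()
torch.cuda.empty_cache()
```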

Comment on lines +760 to +761
text_encoder_one.gradient_checkpointing_enable()
text_encoder_two.gradient_checkpointing_enable()
Contributor

The pix2pix script uses text_encoder_1 and text_encoder_2, which match the actual internal names. Can we make them consistent?

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Collaborator

@yiyixuxu left a comment

cool!

@patrickvonplaten merged commit ea1fcc2 into main on Aug 4, 2023
10 checks passed
@patrickvonplaten deleted the improve_backprop branch on August 4, 2023 at 18:06
@StanislawKarnacky

StanislawKarnacky commented Aug 6, 2023

I successfully ran this code on the free version of Colab with a T4 GPU, but a processing time of ~3 hours for 500 steps is confusing. I have trained a lot of regular LoRAs, but never Dreambooth models, especially under these specific conditions, so sorry for the noob questions. I took the exact settings from the dog example, but how free am I to change them without running out of RAM/VRAM? Also, are the default/example settings universal? For example, I have 77 pictures of my own dog; should I add more or reduce the amount of data? And, as I assume, --max_train_steps is the number of steps after which the model will be saved? Usually my training ends at 10 epochs, which gives me the opportunity to test and choose the best one.
Thank you.
Update 2: I ran out of VRAM and RAM twice after all 500 steps.
My settings are exactly as in the dog example:
from accelerate.utils import write_basic_config
write_basic_config()

arguments = [
    "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-xl-base-1.0",
    "--instance_data_dir", "/content/drive/MyDrive/Loras/SDXL/Data",
    "--instance_prompt", "A photo of sks dog",
    "--output_dir", "/content/drive/MyDrive/Loras/SDXL/Output",
    "--seed", "0",
    "--train_batch_size", "1",
    "--learning_rate", "1e-4",
    "--lr_scheduler", "constant",
    "--mixed_precision", "fp16",
    "--resolution", "1024",
    "--gradient_accumulation_steps", "4",
    "--lr_warmup_steps", "0",
    "--max_train_steps", "500",
    "--validation_prompt", "A photo of sks dog",
    "--validation_epochs", "25",
    "--enable_xformers_memory_efficient_attention",
    "--gradient_checkpointing",
    "--use_8bit_adam",
]
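For context, one way an argument list like this can be handed to the script from a notebook; a hedged sketch that assumes train_dreambooth_lora_sdxl.py is in the working directory and accelerate is already configured:

```python
import subprocess

# Launch the training script with the argument list above via accelerate.
subprocess.run(
    ["accelerate", "launch", "train_dreambooth_lora_sdxl.py", *arguments],
    check=True,
)
```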

@StanislawKarnacky

StanislawKarnacky commented Aug 8, 2023

I give up for now. After ~10 attempts, using another VAE as suggested in the dog example, I managed to get the total time for 500 steps down to ~1h30m on Google Colab with a T4. But every time, after all 500 steps, I run out of VRAM or RAM, sometimes by 2 MB, sometimes by 20. I duplicate my arguments below. If somebody has any suggestions, please let me know!
arguments = [
    "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-xl-base-1.0",
    "--instance_data_dir", "/content/drive/MyDrive/Loras/SDXL/Data",
    "--instance_prompt", "A photo of sks dog",
    "--output_dir", "/content/drive/MyDrive/Loras/SDXL/Output",
    "--seed", "0",
    "--train_batch_size", "1",
    "--learning_rate", "1e-4",
    "--lr_scheduler", "constant",
    "--mixed_precision", "fp16",
    "--resolution", "1024",
    "--gradient_accumulation_steps", "4",
    "--lr_warmup_steps", "0",
    "--max_train_steps", "500",
    "--validation_prompt", "A photo of sks dog",
    "--validation_epochs", "25",
    "--enable_xformers_memory_efficient_attention",
    "--gradient_checkpointing",
    "--use_8bit_adam",
    "--pretrained_vae_model_name_or_path", "madebyollin/sdxl-vae-fp16-fix",
]

@sayakpaul
Member

Does this help?

https://colab.research.google.com/drive/1CvDFD-vDohrerhpiNRLtn32ON2ARgBnX?usp=sharing

@StanislawKarnacky

StanislawKarnacky commented Aug 8, 2023

> Does this help?
>
> https://colab.research.google.com/drive/1CvDFD-vDohrerhpiNRLtn32ON2ARgBnX?usp=sharing

Yes, it does, thank you! I don't know how you managed to do that. In my case, I just used the code from "train_dreambooth_lora_sdxl.py" with the suggested args and "from accelerate.utils import write_basic_config".

@dai-ichiro

dai-ichiro commented Aug 9, 2023

> [quotes StanislawKarnacky's Aug 8 comment and argument list above]

I got the same error.
The default value of the checkpointing_steps arg is 500. That may be the cause of your OOM.
In my case, the OOM occurs when training reaches the first checkpointing_steps.

@sayakpaul
Member

Try using a checkpointing_steps value larger than your total maximum training steps.
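For example, on top of the argument list shared earlier in this thread, that could look like the following (any value larger than --max_train_steps works):

```python
# Make intermediate checkpointing a no-op: with --max_train_steps 500, the
# default --checkpointing_steps 500 would otherwise save a checkpoint (and
# trigger the OOM) right at the end of training.
arguments += [
    "--checkpointing_steps", "1000",
]
```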

@dai-ichiro

dai-ichiro commented Aug 9, 2023

Thank you for your quick response.

When I use a checkpointing_steps value larger than the total maximum training steps, training and saving the results work fine.

I want to resume training from the saved checkpoints, but I don't know how to do that.
Is there a way to resume training without using checkpointing_steps?

@sayakpaul
Member

  • Not sure how that would be possible.
  • For checkpointing, use a larger GPU if possible.

@dai-ichiro

Thank you very much.
I will try it soon.

@bghira
Contributor

bghira commented Aug 9, 2023

Checkpointing likely consumes a lot of VRAM because the image validations end up loading the text encoder again.

@StanislawKarnacky

StanislawKarnacky commented Aug 9, 2023

@sayakpaul sorry for bothering you, but is there any chance to convert the resulting .bin to .safetensors, or to save it in safetensors format straight away? The Kohya_ss script can already train directly to safetensors, but I presume there is no chance of running it on a T4...
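One possible way to convert afterwards, as a minimal sketch using the safetensors library (the file names are placeholders):

```python
import torch
from safetensors.torch import save_file

# Load the LoRA state dict saved by the training script (.bin is a regular
# torch pickle) and re-save the tensors in safetensors format.
state_dict = torch.load("pytorch_lora_weights.bin", map_location="cpu")
save_file(state_dict, "pytorch_lora_weights.safetensors")
```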

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
…face#4470)

* correct

* correct blocks

* finish

* finish

* finish

* Apply suggestions from code review

* fix

* up

* up

* up

* Update examples/dreambooth/README_sdxl.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Apply suggestions from code review

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024