Skip to content

Commit

Permalink
fixed SDXL text encoder training bug huggingface#5016 (huggingface#5078)
Browse files Browse the repository at this point in the history
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
  • Loading branch information
2 people authored and linoytsaban committed Oct 24, 2023
1 parent 87b7087 commit 12dff21
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
if args.train_text_encoder:
text_encoder_one.train()
text_encoder_two.train()

# set top parameter requires_grad = True for gradient checkpointing works
text_encoder_one.text_model.embeddings.requires_grad_(True)
text_encoder_two.text_model.embeddings.requires_grad_(True)

for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
Expand Down

0 comments on commit 12dff21

Please sign in to comment.