Skip to content

Commit

Permalink
make style
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten authored and linoytsaban committed Oct 24, 2023
1 parent 12dff21 commit 0031ed7
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,11 +1323,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
if args.train_text_encoder:
text_encoder_one.train()
text_encoder_two.train()

# set top parameter requires_grad = True for gradient checkpointing works
text_encoder_one.text_model.embeddings.requires_grad_(True)
text_encoder_two.text_model.embeddings.requires_grad_(True)

for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
Expand Down

0 comments on commit 0031ed7

Please sign in to comment.