[PEFT] Adapt example scripts to use PEFT #5388

Merged
merged 38 commits on Dec 7, 2023

Changes from 2 commits

Commits (38)
2cbc52a
adapt example scripts to use PEFT
younesbelkada Oct 13, 2023
b86543f
Update examples/text_to_image/train_text_to_image_lora.py
younesbelkada Oct 13, 2023
af99c12
fix
younesbelkada Oct 16, 2023
cee4255
Merge remote-tracking branch 'upstream/main' into adapt-example-peft
younesbelkada Oct 23, 2023
89d4bed
add for SDXL
younesbelkada Oct 23, 2023
4281913
oops
younesbelkada Oct 23, 2023
b18d0e8
Merge remote-tracking branch 'upstream/main' into adapt-example-peft
younesbelkada Nov 5, 2023
6a48ad0
make sure to install peft
younesbelkada Nov 5, 2023
069a929
fix
younesbelkada Nov 5, 2023
11fd6f5
Merge remote-tracking branch 'upstream/main' into adapt-example-peft
younesbelkada Nov 14, 2023
e4b0f1d
fix
younesbelkada Nov 14, 2023
62c33c0
fix dreambooth and lora
younesbelkada Nov 14, 2023
a1e1cdf
more fixes
younesbelkada Nov 14, 2023
c3d3002
add peft to requirements.txt
younesbelkada Nov 14, 2023
340150b
fix
younesbelkada Nov 14, 2023
dff2995
final fix
younesbelkada Nov 14, 2023
f9d1b5b
Merge branch 'main' into adapt-example-peft
sayakpaul Nov 16, 2023
978d0cd
add peft version in requirements
younesbelkada Nov 16, 2023
f171404
remove comment
younesbelkada Nov 16, 2023
a2f3f20
change variable names
younesbelkada Nov 16, 2023
14b0dd2
add few lines in readme
younesbelkada Nov 16, 2023
8b3b773
Merge branch 'main' into adapt-example-peft
sayakpaul Nov 17, 2023
b21064f
add to reqs
younesbelkada Nov 17, 2023
7827851
Merge remote-tracking branch 'upstream/main' into adapt-example-peft
younesbelkada Nov 20, 2023
b4e108b
style
younesbelkada Nov 20, 2023
17739ae
Merge branch 'main' into adapt-example-peft
younesbelkada Nov 23, 2023
75c3948
fix issues
younesbelkada Nov 23, 2023
1e94c4b
fix lora dreambooth xl tests
younesbelkada Nov 23, 2023
18552c4
Merge branch 'main' into adapt-example-peft
sayakpaul Nov 27, 2023
fe85d1e
Merge branch 'main' into adapt-example-peft
sayakpaul Nov 29, 2023
ada6ad8
init_lora_weights to gaussian and add out proj where missing
sayakpaul Nov 29, 2023
252dcda
ammend requirements.
sayakpaul Nov 29, 2023
90b760a
ammend requirements.txt
sayakpaul Nov 29, 2023
0692a13
Merge branch 'main' into adapt-example-peft
sayakpaul Nov 29, 2023
f27fb29
Merge branch 'main' into adapt-example-peft
sayakpaul Nov 29, 2023
99df659
Merge remote-tracking branch 'upstream/main' into adapt-example-peft
younesbelkada Dec 6, 2023
32ffcf1
Merge branch 'adapt-example-peft' of https://github.com/younesbelkada…
younesbelkada Dec 6, 2023
57516ed
add correct peft versions
younesbelkada Dec 6, 2023
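
At this point in the review ("Changes from 2 commits"), the core change is that the example script drops the manual `LoRAAttnProcessor` / `AttnProcsLayers` bookkeeping and instead injects LoRA layers through PEFT. A minimal sketch of the new pattern, condensed from the diff below — the model id, rank, and learning rate are illustrative placeholders, and the `target_modules` mirror the ones shown at this commit:

# Sketch of the PEFT-based LoRA setup introduced by this PR (condensed from
# the diff below). Model id, rank, and learning rate are illustrative.
import torch
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Freeze the base UNet so only the injected LoRA layers are trainable.
for param in unet.parameters():
    param.requires_grad_(False)

# target_modules follow the ones shown at this commit of the diff.
unet_lora_config = LoraConfig(
    r=4,
    target_modules=["conv1", "conv2", "conv_shortcut", "proj_in", "proj_out"],
)
unet.add_adapter(unet_lora_config)

# Only the LoRA parameters go to the optimizer.
lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
optimizer = torch.optim.AdamW(lora_layers, lr=1e-4)
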
71 changes: 31 additions & 40 deletions examples/text_to_image/train_text_to_image_lora.py
@@ -34,14 +34,14 @@
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
@@ -439,44 +439,20 @@ def main():
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16

# Freeze the unet parameters before adding adapters
for param in unet.parameters():
param.requires_grad_(False)

unet_lora_config = LoraConfig(
r=args.rank, target_modules=["conv1", "conv2", "conv_shortcut", "proj_in", "proj_out"]
)

# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)

# now we will add new LoRA weights to the attention layers
# It's important to realize here how many attention weights will be added and of which sizes
# The sizes of the attention layers consist only of two different variables:
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.

# Let's first see how many attention processors we will have to set.
# For Stable Diffusion, it should be equal to:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
# - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
# => 32 layers

# Set correct lora layers
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]

lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=args.rank,
)

unet.set_attn_processor(lora_attn_procs)
unet.add_adapter(unet_lora_config)

if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
@@ -491,7 +467,7 @@
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

lora_layers = AttnProcsLayers(unet.attn_processors)
lora_layers = filter(lambda p: p.requires_grad, unet.parameters())

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
@@ -517,7 +493,7 @@
optimizer_cls = torch.optim.AdamW

optimizer = optimizer_cls(
lora_layers.parameters(),
lora_layers,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -777,7 +753,7 @@ def collate_fn(examples):
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = lora_layers.parameters()
params_to_clip = lora_layers
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
@@ -814,6 +790,15 @@ def collate_fn(examples):

save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)

unet_lora_state_dict = get_peft_model_state_dict(unet)

StableDiffusionPipeline.save_lora_weights(
save_directory=save_path,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)

logger.info(f"Saved state to {save_path}")

logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -869,7 +854,13 @@ def collate_fn(examples):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unet.to(torch.float32)
unet.save_attn_procs(args.output_dir)

unet_lora_state_dict = get_peft_model_state_dict(unet)
DiffusionPipeline.save_pretrained(
args.output_dir,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)

if args.push_to_hub:
save_model_card(
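
After training, the LoRA weights written via `get_peft_model_state_dict` and `save_lora_weights` above can be loaded back into a pipeline for inference. A minimal sketch — the base model id, output path, and prompt are illustrative placeholders:

# Sketch of loading the LoRA weights produced by the updated script.
# Base model id, output path, and prompt are illustrative placeholders.
import torch
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipeline.to("cuda")

# load_lora_weights applies the saved LoRA layers on top of the frozen base UNet.
pipeline.load_lora_weights("path/to/output_dir")

image = pipeline("a drawing of a blue pokemon", num_inference_steps=30).images[0]
image.save("lora_sample.png")
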