[PEFT] Adapt example scripts to use PEFT #5388

Merged: 38 commits, Dec 7, 2023

Commits
2cbc52a
adapt example scripts to use PEFT
younesbelkada Oct 13, 2023
b86543f
Update examples/text_to_image/train_text_to_image_lora.py
younesbelkada Oct 13, 2023
af99c12
fix
younesbelkada Oct 16, 2023
cee4255
Merge remote-tracking branch 'upstream/main' into adapt-example-peft
younesbelkada Oct 23, 2023
89d4bed
add for SDXL
younesbelkada Oct 23, 2023
4281913
oops
younesbelkada Oct 23, 2023
b18d0e8
Merge remote-tracking branch 'upstream/main' into adapt-example-peft
younesbelkada Nov 5, 2023
6a48ad0
make sure to install peft
younesbelkada Nov 5, 2023
069a929
fix
younesbelkada Nov 5, 2023
11fd6f5
Merge remote-tracking branch 'upstream/main' into adapt-example-peft
younesbelkada Nov 14, 2023
e4b0f1d
fix
younesbelkada Nov 14, 2023
62c33c0
fix dreambooth and lora
younesbelkada Nov 14, 2023
a1e1cdf
more fixes
younesbelkada Nov 14, 2023
c3d3002
add peft to requirements.txt
younesbelkada Nov 14, 2023
340150b
fix
younesbelkada Nov 14, 2023
dff2995
final fix
younesbelkada Nov 14, 2023
f9d1b5b
Merge branch 'main' into adapt-example-peft
sayakpaul Nov 16, 2023
978d0cd
add peft version in requirements
younesbelkada Nov 16, 2023
f171404
remove comment
younesbelkada Nov 16, 2023
a2f3f20
change variable names
younesbelkada Nov 16, 2023
14b0dd2
add few lines in readme
younesbelkada Nov 16, 2023
8b3b773
Merge branch 'main' into adapt-example-peft
sayakpaul Nov 17, 2023
b21064f
add to reqs
younesbelkada Nov 17, 2023
7827851
Merge remote-tracking branch 'upstream/main' into adapt-example-peft
younesbelkada Nov 20, 2023
b4e108b
style
younesbelkada Nov 20, 2023
17739ae
Merge branch 'main' into adapt-example-peft
younesbelkada Nov 23, 2023
75c3948
fix issues
younesbelkada Nov 23, 2023
1e94c4b
fix lora dreambooth xl tests
younesbelkada Nov 23, 2023
18552c4
Merge branch 'main' into adapt-example-peft
sayakpaul Nov 27, 2023
fe85d1e
Merge branch 'main' into adapt-example-peft
sayakpaul Nov 29, 2023
ada6ad8
init_lora_weights to gaussian and add out proj where missing
sayakpaul Nov 29, 2023
252dcda
ammend requirements.
sayakpaul Nov 29, 2023
90b760a
ammend requirements.txt
sayakpaul Nov 29, 2023
0692a13
Merge branch 'main' into adapt-example-peft
sayakpaul Nov 29, 2023
f27fb29
Merge branch 'main' into adapt-example-peft
sayakpaul Nov 29, 2023
99df659
Merge remote-tracking branch 'upstream/main' into adapt-example-peft
younesbelkada Dec 6, 2023
32ffcf1
Merge branch 'adapt-example-peft' of https://github.com/younesbelkada…
younesbelkada Dec 6, 2023
57516ed
add correct peft versions
younesbelkada Dec 6, 2023
1 change: 1 addition & 0 deletions .github/workflows/pr_tests.yml
@@ -113,6 +113,7 @@ jobs:
- name: Run example PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: |
python -m pip install peft
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \
examples
1 change: 1 addition & 0 deletions examples/dreambooth/README.md
@@ -44,6 +44,7 @@ write_basic_config()
```

When running `accelerate config`, specifying torch compile mode as True can give dramatic speedups.
Note also that we use the PEFT library as the backend for LoRA training, so make sure you have `peft>=0.6.0` installed in your environment.

### Dog toy example

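The READMEs state a `peft>=0.6.0` minimum while the requirement files pin `peft==0.7.0`, so a quick check of the installed version before launching training can save a failed run. A minimal sketch, not part of the diff:

```python
# Sanity check (illustrative, not part of this PR): confirm the installed
# PEFT version satisfies the new minimum required by the example scripts.
from packaging import version

import peft

assert version.parse(peft.__version__) >= version.parse("0.6.0"), (
    f"Found peft=={peft.__version__}; the reworked LoRA examples need >= 0.6.0"
)
```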
1 change: 1 addition & 0 deletions examples/dreambooth/README_sdxl.md
@@ -47,6 +47,7 @@ write_basic_config()
```

When running `accelerate config`, specifying torch compile mode as True can give dramatic speedups.
Note also that we use the PEFT library as the backend for LoRA training, so make sure you have `peft>=0.6.0` installed in your environment.

### Dog toy example

1 change: 1 addition & 0 deletions examples/dreambooth/requirements.txt
@@ -4,3 +4,4 @@ transformers>=4.25.1
ftfy
tensorboard
Jinja2
peft==0.7.0
1 change: 1 addition & 0 deletions examples/dreambooth/requirements_sdxl.txt
@@ -4,3 +4,4 @@ transformers>=4.25.1
ftfy
tensorboard
Jinja2
peft==0.7.0
126 changes: 27 additions & 99 deletions examples/dreambooth/train_dreambooth_lora.py
@@ -16,7 +16,6 @@
import argparse
import copy
import gc
import itertools
import logging
import math
import os
@@ -35,6 +34,8 @@
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
@@ -52,14 +53,7 @@
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.attention_processor import (
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
SlicedAttnAddedKVProcessor,
)
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

@@ -864,79 +858,19 @@ def main(args):
text_encoder.gradient_checkpointing_enable()

# now we will add new LoRA weights to the attention layers
# It's important to realize here how many attention weights will be added and of which sizes
# The sizes of the attention layers consist only of two different variables:
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.

# Let's first see how many attention processors we will have to set.
# For Stable Diffusion, it should be equal to:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
# - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
# => 32 layers

# Set correct lora layers
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)

# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)

# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())

if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
attn_module.add_k_proj.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.add_k_proj.in_features,
out_features=attn_module.add_k_proj.out_features,
rank=args.rank,
)
)
attn_module.add_v_proj.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.add_v_proj.in_features,
out_features=attn_module.add_v_proj.out_features,
rank=args.rank,
)
)
unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
unet_lora_config = LoraConfig(
r=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
)
unet.add_adapter(unet_lora_config)

Member comment: @younesbelkada in your previous iterations, we were missing "to_out.0" in the target_modules for some scripts. I have added that.

# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
# The text encoder comes from 🤗 transformers; we will also attach adapters to it.
if args.train_text_encoder:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank)
text_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
)
text_encoder.add_adapter(text_lora_config)

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
@@ -948,9 +882,9 @@ def save_model_hook(models, weights, output_dir):

for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = unet_lora_state_dict(model)
unet_lora_layers_to_save = get_peft_model_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1010,11 +944,10 @@ def load_model_hook(models, input_dir):
optimizer_class = torch.optim.AdamW

# Optimizer creation
params_to_optimize = (
itertools.chain(unet_lora_parameters, text_lora_parameters)
if args.train_text_encoder
else unet_lora_parameters
)
params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
if args.train_text_encoder:
params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters()))

optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
@@ -1257,12 +1190,7 @@ def compute_text_embeddings(prompt):

accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(unet_lora_parameters, text_lora_parameters)
if args.train_text_encoder
else unet_lora_parameters
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
@@ -1385,19 +1313,19 @@ def compute_text_embeddings(prompt):
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = unet_lora_state_dict(unet)

if text_encoder is not None and args.train_text_encoder:
unet_lora_state_dict = get_peft_model_state_dict(unet)

if args.train_text_encoder:
text_encoder = accelerator.unwrap_model(text_encoder)
text_encoder = text_encoder.to(torch.float32)
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
else:
text_encoder_lora_layers = None
text_encoder_state_dict = None

LoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
unet_lora_layers=unet_lora_state_dict,
text_encoder_lora_layers=text_encoder_state_dict,
)

# Final inference
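Taken together, this file's diff replaces the manual `LoRALinearLayer` wiring with PEFT's adapter API. Below is a condensed sketch of the new flow, not the script itself; the checkpoint name, rank, and output path are illustrative stand-ins for values the script reads from `args`:

```python
# Condensed sketch of the PEFT-based LoRA flow this PR adopts for the UNet.
import torch
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

from diffusers import UNet2DConditionModel
from diffusers.loaders import LoraLoaderMixin

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # illustrative checkpoint
)

# One LoraConfig covers every attention projection, including the output
# projection ("to_out.0") and the added-KV projections some blocks use.
unet_lora_config = LoraConfig(
    r=4,  # the script passes args.rank here
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
)
unet.add_adapter(unet_lora_config)

# Only the injected LoRA weights require gradients, so the optimizer (and
# gradient clipping) can work from a simple requires_grad filter.
params_to_optimize = [p for p in unet.parameters() if p.requires_grad]

# Saving extracts just the adapter weights instead of a full state dict.
LoraLoaderMixin.save_lora_weights(
    save_directory="lora-output",  # illustrative path
    unet_lora_layers=get_peft_model_state_dict(unet),
)
```

This is why `itertools` and the attention-processor imports could be dropped: parameter collection and serialization no longer need per-module bookkeeping.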
90 changes: 23 additions & 67 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -34,6 +34,8 @@
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
@@ -50,9 +52,8 @@
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr, unet_lora_state_dict
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

@@ -1009,54 +1010,19 @@ def main(args):
text_encoder_two.gradient_checkpointing_enable()

# now we will add new LoRA weights to the attention layers
# Set correct lora layers
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)

# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)

# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
unet_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
)
unet.add_adapter(unet_lora_config)

# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
text_encoder_one, dtype=torch.float32, rank=args.rank
)
text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
text_encoder_two, dtype=torch.float32, rank=args.rank
text_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
@@ -1069,11 +1035,11 @@ def save_model_hook(models, weights, output_dir):

for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = unet_lora_state_dict(model)
unet_lora_layers_to_save = get_peft_model_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1130,6 +1096,12 @@ def load_model_hook(models, input_dir):
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)

unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))

if args.train_text_encoder:
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))

# Optimization parameters
unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate}
if args.train_text_encoder:
@@ -1194,26 +1166,10 @@ def load_model_hook(models, input_dir):

optimizer_class = prodigyopt.Prodigy

if args.learning_rate <= 0.1:
logger.warn(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
if args.train_text_encoder and args.text_encoder_lr:
logger.warn(
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
)
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
# --learning_rate
params_to_optimize[1]["lr"] = args.learning_rate
params_to_optimize[2]["lr"] = args.learning_rate

optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
decouple=args.prodigy_decouple,
@@ -1659,13 +1615,13 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = unet_lora_state_dict(unet)
unet_lora_layers = get_peft_model_state_dict(unet)

if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
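The SDXL script follows the same pattern, with two differences visible in the diff: both CLIP text encoders share one `LoraConfig`, and trainable parameters are collected per model, after the adapters are attached, to build separate optimizer parameter groups. A sketch under illustrative names and values:

```python
# Sketch of the SDXL text-encoder adapter setup after this PR; the model ID,
# rank, and learning rates are illustrative, not taken from the script.
from peft import LoraConfig
from transformers import CLIPTextModel, CLIPTextModelWithProjection

base = "stabilityai/stable-diffusion-xl-base-1.0"
text_encoder_one = CLIPTextModel.from_pretrained(base, subfolder="text_encoder")
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    base, subfolder="text_encoder_2"
)

# A single config is attached to both encoders.
text_lora_config = LoraConfig(
    r=4,
    init_lora_weights="gaussian",
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)

# Trainable parameters are gathered per model only after add_adapter, which
# is why the diff moves this collection below the adapter setup.
text_lora_parameters_one = [p for p in text_encoder_one.parameters() if p.requires_grad]
text_lora_parameters_two = [p for p in text_encoder_two.parameters() if p.requires_grad]

# Separate parameter groups allow a distinct text-encoder learning rate.
params_to_optimize = [
    {"params": text_lora_parameters_one, "lr": 5e-5},
    {"params": text_lora_parameters_two, "lr": 5e-5},
]
```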
2 changes: 2 additions & 0 deletions examples/text_to_image/README.md
@@ -32,6 +32,8 @@ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) e
accelerate config
```

Note also that we use the PEFT library as the backend for LoRA training, so make sure you have `peft>=0.6.0` installed in your environment.

### Pokemon example

You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
1 change: 1 addition & 0 deletions examples/text_to_image/README_sdxl.md
@@ -45,6 +45,7 @@ write_basic_config()
```

When running `accelerate config`, specifying torch compile mode as True can give dramatic speedups.
Note also that we use the PEFT library as the backend for LoRA training, so make sure you have `peft>=0.6.0` installed in your environment.

### Training

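For completeness, weights saved through `LoraLoaderMixin.save_lora_weights` by these scripts can be loaded back for inference with the pipeline-level loader. A hedged example; the model ID, output path, and prompt are illustrative:

```python
# Illustrative inference-side counterpart: load LoRA weights produced by the
# updated training scripts into a pipeline.
import torch

from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("path/to/lora-output")

image = pipeline("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("dog.png")
```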