Add support for fine-tuning with LoRA (text2image example) #2002

Closed · wants to merge 57 commits
4eb297e
[Lora] first upload
patrickvonplaten Jan 2, 2023
67f4e5a
add first lora version
patrickvonplaten Jan 2, 2023
24993c4
upload
patrickvonplaten Jan 2, 2023
943e7f4
more
patrickvonplaten Jan 2, 2023
e7293d0
first training
patrickvonplaten Jan 3, 2023
0baadb1
Merge branch 'main' of https://github.com/huggingface/diffusers into …
patrickvonplaten Jan 3, 2023
b8e9ce4
up
patrickvonplaten Jan 3, 2023
f7719e0
correct
patrickvonplaten Jan 3, 2023
b69f276
improve
patrickvonplaten Jan 13, 2023
5d6ee56
finish loaders and inference
patrickvonplaten Jan 13, 2023
bc15289
up
patrickvonplaten Jan 13, 2023
c780bdc
feat: add support for fine-tuning with LoRA (text2img).
sayakpaul Jan 16, 2023
67a388f
chore: apply formatting.
sayakpaul Jan 16, 2023
bcc8d12
unet device placement.
sayakpaul Jan 16, 2023
4c811cf
fix: lora layer preparation from accelerate
sayakpaul Jan 16, 2023
d0eff5c
defaulting to weight type code.
sayakpaul Jan 16, 2023
d0a038d
small nit.
sayakpaul Jan 16, 2023
aa8ad74
fix
patrickvonplaten Jan 16, 2023
d334d5a
up
patrickvonplaten Jan 16, 2023
d8f1a6b
fix more
patrickvonplaten Jan 16, 2023
c5cf0a0
up
patrickvonplaten Jan 16, 2023
060697e
finish more
patrickvonplaten Jan 16, 2023
e86d8d6
Merge remote-tracking branch 'origin/add_lora_fine_tuning' into feat/…
sayakpaul Jan 17, 2023
1723671
add: readme notes for LoRA.
sayakpaul Jan 17, 2023
f4f7b41
apply make style and make quality.
sayakpaul Jan 17, 2023
5d5fd77
finish more
patrickvonplaten Jan 17, 2023
608e53c
repo creation fix.
sayakpaul Jan 17, 2023
0998df6
Merge remote-tracking branch 'origin/add_lora_fine_tuning' into feat/…
sayakpaul Jan 17, 2023
b7478ef
up
patrickvonplaten Jan 17, 2023
79c4019
fix: create_repo().
sayakpaul Jan 17, 2023
a2b3ef6
Merge remote-tracking branch 'origin/add_lora_fine_tuning' into feat/…
sayakpaul Jan 17, 2023
1530b76
up
patrickvonplaten Jan 17, 2023
17850de
change year
patrickvonplaten Jan 17, 2023
8f2c81f
changes aligning with dreambooth lora.
sayakpaul Jan 17, 2023
fb8ce5f
revert year change
patrickvonplaten Jan 17, 2023
cbe6ef7
Change lines
patrickvonplaten Jan 17, 2023
0bb68ed
set unet requires_grad False.
sayakpaul Jan 17, 2023
b4bcc26
Add cloneofsimo as co-author.
patrickvonplaten Jan 17, 2023
3d693c0
finish
patrickvonplaten Jan 17, 2023
f53d962
fix docs
patrickvonplaten Jan 17, 2023
f0e36c2
no explicit weight casting.
sayakpaul Jan 17, 2023
2df28cc
disable casting of pixels.
sayakpaul Jan 17, 2023
7013b96
quality changes and edits to readme.
sayakpaul Jan 17, 2023
afc7bf3
change to half-precision for val inference.
sayakpaul Jan 17, 2023
0cef303
Merge branch 'add_lora_fine_tuning' into feat/lora-fit
sayakpaul Jan 17, 2023
3247ef5
fix: device placement.
sayakpaul Jan 17, 2023
4482f77
ifx: code quality.
sayakpaul Jan 17, 2023
c027340
Merge branch 'main' into feat/lora-fit
sayakpaul Jan 17, 2023
66fc1a5
autocasting.
sayakpaul Jan 17, 2023
b6245e2
disable half precision in pipeline.
sayakpaul Jan 17, 2023
7159ca0
fix autocasting.
sayakpaul Jan 17, 2023
c1f1844
device in autocast.
sayakpaul Jan 17, 2023
f4fb615
revert to weight_dtype.
sayakpaul Jan 18, 2023
ca75257
run inference one by one.
sayakpaul Jan 18, 2023
4570ae7
Merge branch 'main' into feat/lora-fit
sayakpaul Jan 18, 2023
16776da
minor nit in the readme.
sayakpaul Jan 18, 2023
866d13c
Merge branch 'main' into feat/lora-fit
sayakpaul Jan 18, 2023
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -90,6 +90,8 @@
title: Configuration
- local: api/outputs
title: Outputs
- local: api/loaders
title: Loaders
title: Main Classes
- sections:
- local: api/pipelines/overview
30 changes: 30 additions & 0 deletions docs/source/en/api/loaders.mdx
@@ -0,0 +1,30 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Loaders

There are many ways to train adapter neural networks for diffusion models, such as:
- [Textual Inversion](./training/text_inversion.mdx)
- [LoRA](https://github.com/cloneofsimo/lora)
- [Hypernetworks](https://arxiv.org/abs/1609.09106)

Such adapter neural networks often consist of only a fraction of the number of weights of the
pretrained model and as such are very portable. The Diffusers library offers an easy-to-use
API to load such adapter neural networks via the [`loaders.py` module](https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders.py).

**Note**: This module is still highly experimental and subject to change in the future.

## LoaderMixins

### UNet2DConditionLoadersMixin

[[autodoc]] loaders.UNet2DConditionLoadersMixin
2 changes: 1 addition & 1 deletion docs/source/en/api/logging.mdx
@@ -1,4 +1,4 @@
- <!--Copyright 2020 The HuggingFace Team. All rights reserved.
+ <!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
95 changes: 95 additions & 0 deletions examples/dreambooth/README.md
@@ -5,6 +5,7 @@ The `train_dreambooth.py` script shows how to implement the training procedure a


## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:
@@ -235,6 +236,100 @@ image.save("dog-bucket.png")

You can also perform inference from one of the checkpoints saved during the training process, if you used the `--checkpointing_steps` argument. Please, refer to [the documentation](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint) to see how to do it.

## Training with Low-Rank Adaptation of Large Language Models (LoRA)

Low-Rank Adaptation of Large Language Models (LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.

In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers allow you to control to what extent the model is adapted towards new training images via a `scale` parameter.

[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
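The rank-decomposition idea can be sketched in a few lines of NumPy. This is a hypothetical illustration with made-up shapes, not the Diffusers implementation: a frozen weight matrix `W` is augmented with two small trainable matrices `A` and `B` of rank `r`, and the effective weight becomes `W + scale * (B @ A)`.

```python
import numpy as np

# Hypothetical layer size and LoRA rank (for illustration only).
d_in, d_out, r = 768, 768, 4

W = np.random.randn(d_out, d_in)      # frozen pretrained weight
A = np.random.randn(r, d_in) * 0.01   # trainable down-projection
B = np.zeros((d_out, r))              # trainable up-projection, zero-initialized

def forward(x, scale=1.0):
    # LoRA adds a low-rank update on top of the frozen weight;
    # `scale` controls how strongly the adaptation is applied.
    return x @ (W + scale * (B @ A)).T

# Only A and B are trained: far fewer parameters than W itself.
lora_params = A.size + B.size  # 4*768 + 768*4 = 6144
full_params = W.size           # 768*768   = 589824
```

Because `B` is zero-initialized, the adapted layer initially reproduces the pretrained layer exactly, and training only ever moves the `6144` LoRA parameters rather than the `589824` parameters of `W`, which is why the saved weights are so small.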

### Training

Let's get started with a simple example. We will re-use the dog example of the [previous section](#dog-toy-example).

First, you need to set up your dreambooth training example as explained in the [installation section](#Installing-the-dependencies).
Next, let's download the toy dog dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. Make sure to set `INSTANCE_DIR` to the name of that directory further below. This will be our training data.

Now, you can launch the training. Here we will use [Stable Diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5).

**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**

**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [wandb](https://docs.wandb.ai/quickstart) is a nice solution to easily see generated images during training. All you need to do is run `pip install wandb` before training to automatically log images.___**


```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"
```

For this example we want to directly store the trained LoRA embeddings on the Hub, so
we need to be logged in and add the `--push_to_hub` flag.

```bash
huggingface-cli login
```

Now we can start training!

```bash
accelerate launch train_dreambooth_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--checkpointing_steps=100 \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--seed="0" \
--push_to_hub
```

**___Note: When using LoRA we can use a much higher learning rate compared to vanilla dreambooth. Here we
use *1e-4* instead of the usual *2e-6*.___**

The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora](https://huggingface.co/patrickvonplaten/lora). **___Note: [The final weights](https://huggingface.co/patrickvonplaten/lora/blob/main/pytorch_attn_procs.bin) are only 3 MB in size, which is orders of magnitude smaller than the original model.___**

The training results are summarized [here](https://wandb.ai/patrickvonplaten/dreambooth/reports/LoRA-DreamBooth-Dog-Example--VmlldzozMzUzMTcx?accessToken=9drrltpimid0jk8q50p91vwovde24cnimc30g3bjd3i5wys5twi7uczd7jdh85dh).

### Inference

After training, the LoRA weights can be loaded very easily into the original pipeline. First, you need to
load the original pipeline:

```python
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
```

Next, we can load the adapter layers into the UNet with the [`load_attn_procs` function](TODO:).

```python
pipe.load_attn_procs("patrickvonplaten/lora")
```

Finally, we can run the model for inference.

```python
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
```
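Since the LoRA update is purely additive, the trained rank-decomposition matrices can in principle also be folded into the base weights once, up front, after which inference incurs no extra matrix multiplications. A minimal NumPy sketch of that idea, with hypothetical shapes (this is an illustration of the math, not the Diffusers API):

```python
import numpy as np

d, r = 320, 4  # hypothetical attention-projection size and LoRA rank

W = np.random.randn(d, d)  # frozen base weight
A = np.random.randn(r, d)  # trained LoRA down-projection
B = np.random.randn(d, r)  # trained LoRA up-projection

# Fold the low-rank update into the base weight a single time.
W_merged = W + B @ A

x = np.random.randn(1, d)
# The merged matrix reproduces base-plus-adapter output exactly.
assert np.allclose(x @ W_merged.T, x @ W.T + x @ (B @ A).T)
```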

## Training with Flax/JAX

For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.