[Wuerstchen] text to image training script (#5052)
* initial script

* formatting

* prior trainer wip

* add efficient_net_encoder

* add CLIPTextModel

* add prior ema support

* optimizer

* fix typo

* add dataloader

* prompt_embeds and image_embeds

* initial training loop

* fix output_dir

* fix add_noise

* accelerator check

* make effnet_transforms dynamic

* fix training loop

* add validation logging

* use loaded text_encoder

* use PreTrainedTokenizerFast

* load weights from pickle

* save_model_card

* remove unused file

* fix typos

* save prior pipeline in its own folder

* fix imports

* fix pipe_t2i

* scale image_embeds

* remove snr_gamma

* format

* initial lora prior training

* log_validation and save

* initial gradient working

* remove save/load hooks

* set set_attn_processor on prior_prior

* add lora script

* typos

* use LoraLoaderMixin for prior pipeline

* fix usage

* make fix-copies

* use repo_id

* write_lora_layers is a staticmethod

* use defaults

* fix defaults

* undo

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/loaders.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/loaders.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py

* Update src/diffusers/loaders.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/loaders.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* add gradient checkpoint support to prior

* gradient_checkpointing

* formatting

* Update examples/wuerstchen/text_to_image/README.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/README.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/README.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/README.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/README.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/loaders.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/train_text_to_image_prior.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* use default unet and text_encoder

* fix test

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
3 people authored Oct 16, 2023
1 parent 93df5bb commit d03c909
Showing 10 changed files with 2,094 additions and 34 deletions.
93 changes: 93 additions & 0 deletions examples/wuerstchen/text_to_image/README.md
# Würstchen text-to-image fine-tuning

## Running locally with PyTorch

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```

Then cd into the example folder and run
```bash
cd examples/wuerstchen/text_to_image
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```
For this example we want to store the trained LoRA weights directly on the Hub, so we need to be logged in and add the `--push_to_hub` flag to the training command. To log in, run:
```bash
huggingface-cli login
```

## Prior training

You can fine-tune the Würstchen prior model with the `train_text_to_image_prior.py` script. Note that `--gradient_checkpointing` is currently supported for prior model fine-tuning, so you can use it in GPU memory-constrained setups.

<br>

<!-- accelerate_snippet_start -->
```bash
export DATASET_NAME="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image_prior.py \
--mixed_precision="fp16" \
--dataset_name=$DATASET_NAME \
--resolution=768 \
--train_batch_size=4 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--dataloader_num_workers=4 \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--checkpoints_total_limit=3 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--validation_prompts="A robot pokemon, 4k photo" \
--report_to="wandb" \
--push_to_hub \
--output_dir="wuerstchen-prior-pokemon-model"
```
<!-- accelerate_snippet_end -->
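
Once training is done, the prior can be used for generation through the combined Würstchen pipeline. The following is a minimal inference sketch, assuming your run produced a complete pipeline checkpoint; `warp-ai/wuerstchen` below is the base checkpoint and is a placeholder for your own `--output_dir` or Hub repo:

```python
import torch
from diffusers import AutoPipelineForText2Image

# swap in your own --output_dir / Hub repo; "warp-ai/wuerstchen" is the base model
pipeline = AutoPipelineForText2Image.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=torch.float16
).to("cuda")

image = pipeline(
    "A robot pokemon, 4k photo",
    height=768,
    width=768,
    prior_guidance_scale=4.0,  # classifier-free guidance for the prior (stage C) step
).images[0]
image.save("robot-pokemon.png")
```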

## Training with LoRA

Low-Rank Adaptation of Large Language Models (or LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.

In a nutshell, LoRA adapts pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights (see the sketch after this list). This has a couple of advantages:

- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers make it possible to control the extent to which the model is adapted toward new training images via a `scale` parameter.
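
As a concrete illustration, here is a minimal sketch of the idea (an illustrative LoRA-style linear layer, not the diffusers implementation): the pretrained weight stays frozen, and only the two small rank-decomposition matrices are trained.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Wrap a pretrained linear layer with a trainable low-rank update."""

    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)  # pretrained weights stay frozen
        self.lora_down = nn.Linear(base.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_up.weight)  # the update starts as a no-op
        self.scale = scale  # controls how strongly the adaptation is applied

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # base output plus the scaled low-rank correction
        return self.base(x) + self.scale * self.lora_up(self.lora_down(x))


# only the rank-decomposition matrices carry trainable parameters
layer = LoRALinear(nn.Linear(1280, 1280), rank=4)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # 10240 parameters instead of 1280 * 1280 = 1638400
```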


### Prior training

First, you need to set up your development environment as explained in the [installation](#Running-locally-with-PyTorch) section. Make sure to set the `DATASET_NAME` environment variable. Here, we will use the [Pokemon captions dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).

```bash
export DATASET_NAME="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image_lora_prior.py \
--mixed_precision="fp16" \
--dataset_name=$DATASET_NAME --caption_column="text" \
--resolution=768 \
--train_batch_size=8 \
--num_train_epochs=100 --checkpointing_steps=5000 \
--learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
--seed=42 \
--rank=4 \
--validation_prompt="cute dragon creature" \
--report_to="wandb" \
--push_to_hub \
--output_dir="wuerstchen-prior-pokemon-lora"
```
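
After training, the LoRA weights can be loaded into the prior for inference. The following is a minimal sketch, assuming the combined Würstchen pipeline exposes the prior sub-pipeline as `prior_pipe` and that the prior pipeline supports `load_lora_weights` (this commit wires it up via `LoraLoaderMixin`); the LoRA repo id is a placeholder:

```python
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=torch.float16
).to("cuda")
# load the trained LoRA weights into the prior sub-pipeline
# ("your-username/wuerstchen-prior-pokemon-lora" is a placeholder)
pipeline.prior_pipe.load_lora_weights("your-username/wuerstchen-prior-pokemon-lora")

image = pipeline("cute dragon creature").images[0]
image.save("dragon.png")
```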
Empty file.
23 changes: 23 additions & 0 deletions (new EfficientNet encoder module)
import torch.nn as nn
from torchvision.models import efficientnet_v2_l, efficientnet_v2_s

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class EfficientNetEncoder(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, c_latent=16, c_cond=1280, effnet="efficientnet_v2_s"):
        super().__init__()

        # Pretrained torchvision EfficientNetV2 feature extractor as the image backbone;
        # any value other than "efficientnet_v2_s" falls back to the large variant.
        if effnet == "efficientnet_v2_s":
            self.backbone = efficientnet_v2_s(weights="DEFAULT").features
        else:
            self.backbone = efficientnet_v2_l(weights="DEFAULT").features
        # Project the c_cond backbone channels down to c_latent image-embedding channels.
        self.mapper = nn.Sequential(
            nn.Conv2d(c_cond, c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent),  # normalize each latent channel to zero mean and unit variance
        )

    def forward(self, x):
        return self.mapper(self.backbone(x))
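

# Illustrative usage (not part of the commit): EfficientNetV2-S downsamples by
# 32x, so a 768x768 image yields 16-channel, 24x24 image embeddings.
if __name__ == "__main__":
    import torch

    encoder = EfficientNetEncoder()
    with torch.no_grad():
        image_embeds = encoder(torch.randn(1, 3, 768, 768))
    print(image_embeds.shape)  # torch.Size([1, 16, 24, 24])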
7 changes: 7 additions & 0 deletions examples/wuerstchen/text_to_image/requirements.txt
accelerate>=0.16.0
torchvision
transformers>=4.25.1
wandb
huggingface_hub
bitsandbytes
deepspeed
(The remaining changed files are not shown.)