FLUX: Optimize dataset loading logic (#1038)
madroidmaq authored Oct 15, 2024
1 parent 3d62b05 commit f491d47
Showing 6 changed files with 462 additions and 366 deletions.
49 changes: 26 additions & 23 deletions flux/README.md
@@ -21,8 +21,9 @@ The dependencies are minimal, namely:

- `huggingface-hub` to download the checkpoints.
- `regex` for the tokenization
- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script
- `tqdm`, `PIL`, and `numpy` for the scripts
- `sentencepiece` for the T5 tokenizer
- `datasets` for using an HF dataset directly

You can install all of the above with the `requirements.txt` as follows:

@@ -118,17 +119,12 @@ Finetuning

The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell
but ymmv) on a provided image dataset. The dataset folder must have an
`index.json` file with the following format:

```json
{
    "data": [
        {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
        {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
        {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
        ...
    ]
}
`train.jsonl` file with the following format:

```jsonl
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
...
```
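
If you already have a folder of images, such a file can be generated with a few lines of Python. The snippet below is a minimal sketch; the folder path, file extension, and prompt are placeholders for your own data.

```python
# Minimal sketch: write a train.jsonl next to the images.
# The folder path, glob pattern, and prompt are placeholders.
import json
from pathlib import Path

dataset_dir = Path("path/to/dreambooth/dataset/dog6")
prompt = "A photo of sks dog"

with open(dataset_dir / "train.jsonl", "w") as f:
    for image_path in sorted(dataset_dir.glob("*.jpg")):
        # Image paths are stored relative to the dataset folder.
        f.write(json.dumps({"image": image_path.name, "prompt": prompt}) + "\n")
```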

The training script by default trains for 600 iterations with a batch size of
@@ -150,19 +146,15 @@ The training images are the following 5 images [^2]:

![dog6](static/dog6.png)

We start by making the following `index.json` file and placing it in the same
We start by making the following `train.jsonl` file and placing it in the same
folder as the images.

```json
{
    "data": [
        {"image": "00.jpg", "text": "A photo of sks dog"},
        {"image": "01.jpg", "text": "A photo of sks dog"},
        {"image": "02.jpg", "text": "A photo of sks dog"},
        {"image": "03.jpg", "text": "A photo of sks dog"},
        {"image": "04.jpg", "text": "A photo of sks dog"}
    ]
}
```jsonl
{"image": "00.jpg", "prompt": "A photo of sks dog"}
{"image": "01.jpg", "prompt": "A photo of sks dog"}
{"image": "02.jpg", "prompt": "A photo of sks dog"}
{"image": "03.jpg", "prompt": "A photo of sks dog"}
{"image": "04.jpg", "prompt": "A photo of sks dog"}
```

Subsequently we finetune FLUX using the following command:
@@ -175,6 +167,17 @@ python dreambooth.py \
path/to/dreambooth/dataset/dog6
```


Alternatively, you can fine-tune directly on the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6).

```shell
python dreambooth.py \
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
--lora-rank 4 --grad-accumulate 8 \
mlx-community/dreambooth-dog6
```
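
If you want to inspect the hosted dataset before training, the `datasets` package can load it directly. This is a quick sketch; the `image` and `prompt` column names and the `train` split are assumed to mirror the local `train.jsonl` format.

```python
# Quick sketch: download the hosted dataset and look at one example.
# Column names ("image", "prompt") are assumed to match the train.jsonl format.
from datasets import load_dataset

ds = load_dataset("mlx-community/dreambooth-dog6", split="train")
print(ds)                # row count and column names
print(ds[0]["prompt"])   # prompt for the first example
ds[0]["image"].save("sample.png")  # image columns decode to PIL images
```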

The training requires approximately 50GB of RAM and on an M2 Ultra it takes a
bit more than 1 hour.

121 changes: 14 additions & 107 deletions flux/dreambooth.py
@@ -1,7 +1,6 @@
# Copyright © 2024 Apple Inc.

import argparse
import json
import time
from functools import partial
from pathlib import Path
@@ -13,105 +12,8 @@
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image
from tqdm import tqdm

from flux import FluxPipeline


class FinetuningDataset:
    def __init__(self, flux, args):
        self.args = args
        self.flux = flux
        self.dataset_base = Path(args.dataset)
        dataset_index = self.dataset_base / "index.json"
        if not dataset_index.exists():
            raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
        with open(dataset_index, "r") as f:
            self.index = json.load(f)

        self.latents = []
        self.t5_features = []
        self.clip_features = []

    def _random_crop_resize(self, img):
        resolution = self.args.resolution
        width, height = img.size

        a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()

        # Random crop the input image between 0.8 to 1.0 of its original dimensions
        crop_size = (
            max((0.8 + 0.2 * a) * width, resolution[0]),
            max((0.8 + 0.2 * a) * height, resolution[1]),
        )
        pan = (width - crop_size[0], height - crop_size[1])
        img = img.crop(
            (
                pan[0] * b,
                pan[1] * c,
                crop_size[0] + pan[0] * b,
                crop_size[1] + pan[1] * c,
            )
        )

        # Fit the largest rectangle with the ratio of resolution in the image
        # rectangle.
        width, height = crop_size
        ratio = resolution[0] / resolution[1]
        r1 = (height * ratio, height)
        r2 = (width, width / ratio)
        r = r1 if r1[0] <= width else r2
        img = img.crop(
            (
                (width - r[0]) / 2,
                (height - r[1]) / 2,
                (width + r[0]) / 2,
                (height + r[1]) / 2,
            )
        )

        # Finally resize the image to resolution
        img = img.resize(resolution, Image.LANCZOS)

        return mx.array(np.array(img))

    def encode_images(self):
        """Encode the images in the latent space to prepare for training."""
        self.flux.ae.eval()
        for sample in tqdm(self.index["data"]):
            input_img = Image.open(self.dataset_base / sample["image"])
            for i in range(self.args.num_augmentations):
                img = self._random_crop_resize(input_img)
                img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
                x_0 = self.flux.ae.encode(img[None])
                x_0 = x_0.astype(self.flux.dtype)
                mx.eval(x_0)
                self.latents.append(x_0)

    def encode_prompts(self):
        """Pre-encode the prompts so that we don't recompute them during
        training (doesn't allow finetuning the text encoders)."""
        for sample in tqdm(self.index["data"]):
            t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
            t5_feat = self.flux.t5(t5_tok)
            clip_feat = self.flux.clip(clip_tok).pooled_output
            mx.eval(t5_feat, clip_feat)
            self.t5_features.append(t5_feat)
            self.clip_features.append(clip_feat)

    def iterate(self, batch_size):
        xs = mx.concatenate(self.latents)
        t5 = mx.concatenate(self.t5_features)
        clip = mx.concatenate(self.clip_features)
        mx.eval(xs, t5, clip)
        n_aug = self.args.num_augmentations
        while True:
            x_indices = mx.random.permutation(len(self.latents))
            c_indices = x_indices // n_aug
            for i in range(0, len(self.latents), batch_size):
                x_i = x_indices[i : i + batch_size]
                c_i = c_indices[i : i + batch_size]
                yield xs[x_i], t5[c_i], clip[c_i]
from flux import FluxPipeline, Trainer, load_dataset


def generate_progress_images(iteration, flux, args):
@@ -157,7 +59,8 @@ def save_adapters(iteration, flux, args):
    )


if __name__ == "__main__":
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(
description="Finetune Flux to generate images with a specific subject"
)
@@ -247,7 +150,11 @@ def save_adapters(iteration, flux, args):
    )

    parser.add_argument("dataset")
    return parser


if __name__ == "__main__":
    parser = setup_arg_parser()
    args = parser.parse_args()

    # Load the model and set it up for LoRA training. We use the same random
@@ -267,7 +174,7 @@ def save_adapters(iteration, flux, args):
    trainable_params = tree_reduce(
        lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
    )
    print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True)
    print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True)

    # Set up the optimizer and training steps. The steps are a bit verbose to
    # support gradient accumulation together with compilation.
@@ -340,10 +247,10 @@ def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
                    x, t5_feat, clip_feat, guidance, prev_grads
                )

print("Create the training dataset.", flush=True)
dataset = FinetuningDataset(flux, args)
dataset.encode_images()
dataset.encode_prompts()
dataset = load_dataset(args.dataset)
trainer = Trainer(flux, dataset, args)
trainer.encode_dataset()

    guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)

    # An initial generation to compare
@@ -352,7 +259,7 @@ def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
    grads = None
    losses = []
    tic = time.time()
    for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)):
    for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
        loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
        mx.eval(loss, grads, state)
        losses.append(loss.item())
@@ -361,7 +268,7 @@ def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
            toc = time.time()
            peak_mem = mx.metal.get_peak_memory() / 1024**3
            print(
                f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} "
                f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
                f"It/s: {10 / (toc - tic):.3f} "
                f"Peak mem: {peak_mem:.3f} GB",
                flush=True,
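
The `load_dataset` helper and `Trainer` class that `dreambooth.py` now imports from the `flux` package are defined in the remaining changed files, which are not rendered above. As a rough illustration of the dispatch implied by the new call `load_dataset(args.dataset)` (a local folder containing a `train.jsonl` versus a Hugging Face dataset name), a sketch could look like the following; everything beyond that call signature is an assumption, not the actual implementation.

```python
# Illustration only: one plausible shape for the dataset-loading dispatch.
# The real implementation lives in the flux package and may differ.
import json
from pathlib import Path


class LocalDataset:
    """Images plus prompts described by a train.jsonl in a local folder."""

    def __init__(self, base: Path, entries: list[dict]):
        self.base = base
        self.entries = entries

    def __iter__(self):
        for entry in self.entries:
            yield self.base / entry["image"], entry["prompt"]


def load_local_dataset(path: Path):
    index = path / "train.jsonl"
    if not index.exists():
        raise ValueError(f"'{path}' is not a valid finetuning dataset")
    with open(index) as f:
        entries = [json.loads(line) for line in f if line.strip()]
    return LocalDataset(path, entries)


def load_hf_dataset(name: str):
    # Treat anything that is not a local folder as a Hugging Face dataset id.
    from datasets import load_dataset as hf_load_dataset

    return hf_load_dataset(name, split="train")


def load_dataset(dataset: str):
    """Return a local dataset if `dataset` is a folder, else load it from the Hub."""
    path = Path(dataset)
    if path.is_dir():
        return load_local_dataset(path)
    return load_hf_dataset(dataset)
```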
