From 189f7c3f1e8083d811c049364137f25ea4f74be5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 18 Jan 2023 19:11:57 +0100 Subject: [PATCH 01/11] example on fine-tuning with LoRA. --- examples/text_to_image/README.md | 74 +- examples/text_to_image/requirements.txt | 1 + .../text_to_image/train_text_to_image_lora.py | 825 ++++++++++++++++++ 3 files changed, 899 insertions(+), 1 deletion(-) create mode 100644 examples/text_to_image/train_text_to_image_lora.py diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index e98e136a4b31..621cd9f790e5 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -110,7 +110,80 @@ image = pipe(prompt="yoda").images[0] image.save("yoda-pokemon.png") ``` +## Training with LoRA +Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. + +In a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages: + +- Previous pretrained weights are kept frozen so that model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114). +- Rank-decomposition matrices have significantly fewer parameters than orginal model which means that trained LoRA weights are easily portable. +- LoRA attention layers allow to control to which extent the model is adapted torwards new training images via a `scale` parameter. + +[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository. + +With LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption pair dataset +on consumer GPUs like Tesla T4, Tesla V100. + +### Training + +First, you need to set-up your dreambooth training example as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion v1-4](https://hf.co/CompVis/stable-diffusion-v1-4) and the [Pokemons dataset](https://hf.colambdalabs/pokemon-blip-captions). + +**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** + +**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see generating images during training. All you need to do is to run `pip install wandb` before training to automatically log images.___** + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export DATASET_NAME="lambdalabs/pokemon-blip-captions" +``` + +For this example we want to directly store the trained LoRA embeddings on the Hub, so +we need to be logged in and add the `--push_to_hub` flag. + +```bash +huggingface-cli login +``` + +Now we can start training! + +```bash +accelerate --mixed_precision="fp16" launch train_text_to_image_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME --caption_column="text" \ + --resolution=512 --random_flip \ + --train_batch_size=1 \ + --num_train_epochs=100 --checkpointing_steps=5000 \ + --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \ + --seed=42 \ + --output_dir="sd-pokemon-model-lora" \ + --save_sample_prompt="cute dragon creature" --report_to="wandb" +``` + +The above command will also run inference as fine-tuning progresses and will log the results to Weights and Biases. + +**___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*.___** + +The final LoRA embedding weights have been uploaded to [TODO](TODO). **___Note: [The final weights](TODO) are only 3 MB in size which is orders of magnitudes smaller than the original model.** + +### Inference + +Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline` after loading the trained LoRA weights. You +need to pass the `output_dir` for loading the LoRA weights which, in this case, is `sd-pokemon-model-lora`. + +```python +from diffusers import StableDiffusionPipeline +import torch + +model_path = "path_to_saved_model" +pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16) +pipe.unet.load_attn_procs(model_path) +pipe.to("cuda") + +prompt = "A pokemon with green eyes and red legs." +image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] +image.save("pokemon.png") +``` ## Training with Flax/JAX @@ -141,7 +214,6 @@ python train_text_to_image_flax.py \ --output_dir="sd-pokemon-model" ``` - To run on your own training files prepare the dataset according to the format required by `datasets`, you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata). If you wish to use custom loading logic, you should modify the script, we have left pointers for that in the training script. diff --git a/examples/text_to_image/requirements.txt b/examples/text_to_image/requirements.txt index b597d5464f1e..062068904f1d 100644 --- a/examples/text_to_image/requirements.txt +++ b/examples/text_to_image/requirements.txt @@ -5,3 +5,4 @@ datasets ftfy tensorboard modelcards +wandb \ No newline at end of file diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py new file mode 100644 index 000000000000..26ed92a2b24e --- /dev/null +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -0,0 +1,825 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA.""" + +import argparse +import logging +import math +import os +import random +from pathlib import Path +from typing import Dict, Optional + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint + +import datasets +import diffusers +import transformers +import wandb +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from datasets import load_dataset +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.loaders import AttnProcsLayers +from diffusers.models.cross_attention import LoRACrossAttnProcessor +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available +from huggingface_hub import HfFolder, Repository, create_repo, whoami +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.12.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +class LoraLayers(torch.nn.Module): + def __init__(self, state_dict: Dict[str, torch.Tensor]): + super().__init__() + self.layers = torch.nn.ModuleList(state_dict.values()) + self.mapping = {k: v for k, v in enumerate(state_dict.keys())} + self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} + + # we add a hook to state_dict() and load_state_dict() so that the + # naming fits with `unet.attn_processors` + def map_to(module, state_dict, *args, **kwargs): + new_state_dict = {} + for key, value in state_dict.items(): + num = int(key.split(".")[1]) # 0 is always "layers" + new_key = key.replace(f"layers.{num}", module.mapping[num]) + new_state_dict[new_key] = value + + return new_state_dict + + def map_from(module, state_dict, *args, **kwargs): + all_keys = list(state_dict.keys()) + for key in all_keys: + replace_key = ".".join(key.split(".")[:-3]) + new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") + state_dict[new_key] = state_dict[key] + del state_dict[key] + + self._register_state_dict_hook(map_to) + self._register_load_state_dict_pre_hook(map_from, with_module=True) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference." + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=10, + help=( + "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + action="store_true", + help="Whether to center crop images before resizing to resolution (if not set, random crop will be used)", + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +dataset_name_mapping = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + + +def main(): + args = parse_args() + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + logging_dir=logging_dir, + ) + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo_name = create_repo(repo_name, exist_ok=True) + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # now we will add new LoRA weights to the attention layers + # It's important to realize here how many attention weights will be added and of which sizes + # The sizes of the attention layers consist only of two different variables: + # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. + # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. + + # Let's first see how many attention processors we will have to set. + # For Stable Diffusion, it should be equal to: + # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 + # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 + # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18 + # => 32 layers + + # Set correct lora layers + lora_attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRACrossAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + ) + + unet.set_attn_processor(lora_attn_procs) + lora_layers = AttnProcsLayers(unet.attn_processors) + + state_dict = lora_layers.state_dict() + lora_layers.load_state_dict(state_dict) + + accelerator.register_for_checkpointing(lora_layers) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + lora_layers.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + dataset_columns = dataset_name_mapping.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + # Preprocessing the datasets. + train_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["pixel_values"] = [train_transforms(image) for image in images] + examples["input_ids"] = tokenize_captions(examples) + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + input_ids = torch.stack([example["input_ids"] for example in examples]) + return {"pixel_values": pixel_values, "input_ids": input_ids} + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + lora_layers, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("text2image-fine-tune", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = resume_global_step // num_update_steps_per_epoch + resume_step = resume_global_step % num_update_steps_per_epoch + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + # Predict the noise residual and compute loss + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = lora_layers.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + images = [] + for _ in range(args.num_validation_images): + images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) + + for tracker in accelerator.trackers: + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = unet.to(torch.float32) + unet.save_attn_procs(args.output_dir) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + # Final inference + # Load previous pipeline + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype + ) + pipeline = pipeline.to(accelerator.device) + + # load attention processors + pipeline.unet.load_attn_procs(args.output_dir) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + images = [] + for _ in range(args.num_validation_images): + images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) + + for tracker in accelerator.trackers: + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + accelerator.end_training() + + +if __name__ == "__main__": + main() From 0fe84df6cf7d13b098798426bf36d9cbdc43ff19 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 18 Jan 2023 19:15:48 +0100 Subject: [PATCH 02/11] apply make quality. --- examples/text_to_image/train_text_to_image_lora.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 26ed92a2b24e..6941fd4952bb 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -30,7 +30,6 @@ import datasets import diffusers import transformers -import wandb from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed From 2b908d4a5644c82b0065999e9c8589f57d3aa249 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 19 Jan 2023 10:40:02 +0100 Subject: [PATCH 03/11] fix: pipeline loading. --- examples/text_to_image/train_text_to_image_lora.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 6941fd4952bb..6a2d33bb0a22 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -34,7 +34,7 @@ from accelerate.logging import get_logger from accelerate.utils import set_seed from datasets import load_dataset -from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel from diffusers.loaders import AttnProcsLayers from diffusers.models.cross_attention import LoRACrossAttnProcessor from diffusers.optimization import get_scheduler @@ -753,7 +753,7 @@ def collate_fn(examples): f" {args.validation_prompt}." ) # create pipeline - pipeline = StableDiffusionPipeline.from_pretrained( + pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), revision=args.revision, @@ -793,7 +793,7 @@ def collate_fn(examples): # Final inference # Load previous pipeline - pipeline = StableDiffusionPipeline.from_pretrained( + pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype ) pipeline = pipeline.to(accelerator.device) From 5260ec67e11f57c760c37432c6c4e80b8d59f0d2 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 19 Jan 2023 23:07:03 +0100 Subject: [PATCH 04/11] Apply suggestions from code review Co-authored-by: Suraj Patil Co-authored-by: Patrick von Platen --- examples/text_to_image/README.md | 12 +++++----- .../text_to_image/train_text_to_image_lora.py | 22 +++---------------- 2 files changed, 9 insertions(+), 25 deletions(-) diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index 621cd9f790e5..827d68bed4bd 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -114,11 +114,11 @@ image.save("yoda-pokemon.png") Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. -In a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages: +In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages: - Previous pretrained weights are kept frozen so that model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114). -- Rank-decomposition matrices have significantly fewer parameters than orginal model which means that trained LoRA weights are easily portable. -- LoRA attention layers allow to control to which extent the model is adapted torwards new training images via a `scale` parameter. +- Rank-decomposition matrices have significantly fewer parameters than original model, which means that trained LoRA weights are easily portable. +- LoRA attention layers allow to control to which extent the model is adapted toward new training images via a `scale` parameter. [cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository. @@ -127,7 +127,7 @@ on consumer GPUs like Tesla T4, Tesla V100. ### Training -First, you need to set-up your dreambooth training example as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion v1-4](https://hf.co/CompVis/stable-diffusion-v1-4) and the [Pokemons dataset](https://hf.colambdalabs/pokemon-blip-captions). +First, you need to set up your dreambooth training example as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion v1-4](https://hf.co/CompVis/stable-diffusion-v1-4) and the [Pokemons dataset](https://hf.colambdalabs/pokemon-blip-captions). **___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** @@ -160,11 +160,11 @@ accelerate --mixed_precision="fp16" launch train_text_to_image_lora.py \ --save_sample_prompt="cute dragon creature" --report_to="wandb" ``` -The above command will also run inference as fine-tuning progresses and will log the results to Weights and Biases. +The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases. **___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*.___** -The final LoRA embedding weights have been uploaded to [TODO](TODO). **___Note: [The final weights](TODO) are only 3 MB in size which is orders of magnitudes smaller than the original model.** +The final LoRA embedding weights have been uploaded to [TODO](TODO). **___Note: [The final weights](TODO) are only 3 MB in size, which is orders of magnitudes smaller than the original model.** ### Inference diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 6941fd4952bb..98e5151aeed7 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -52,7 +52,6 @@ logger = get_logger(__name__, log_level="INFO") -class LoraLayers(torch.nn.Module): def __init__(self, state_dict: Dict[str, torch.Tensor]): super().__init__() self.layers = torch.nn.ModuleList(state_dict.values()) @@ -145,7 +144,7 @@ def parse_args(): parser.add_argument( "--validation_epochs", type=int, - default=10, + default=1, help=( "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" " `args.validation_prompt` multiple times: `args.num_validation_images`." @@ -248,16 +247,6 @@ def parse_args(): " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) - parser.add_argument( - "--non_ema_revision", - type=str, - default=None, - required=False, - help=( - "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" - " remote repository specified with --pretrained_model_name_or_path." - ), - ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") @@ -332,9 +321,6 @@ def parse_args(): if args.dataset_name is None and args.train_data_dir is None: raise ValueError("Need either a dataset name or a training folder.") - # default to using the same revision for the non-ema model if not specified - if args.non_ema_revision is None: - args.non_ema_revision = args.revision return args @@ -349,7 +335,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: return f"{organization}/{model_id}" -dataset_name_mapping = { +DATASET_NAME_MAPPING = { "lambdalabs/pokemon-blip-captions": ("image", "text"), } @@ -417,7 +403,7 @@ def main(): ) vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) # For mixed precision training we cast the text_encoder and vae weights to half-precision @@ -472,8 +458,6 @@ def main(): unet.set_attn_processor(lora_attn_procs) lora_layers = AttnProcsLayers(unet.attn_processors) - state_dict = lora_layers.state_dict() - lora_layers.load_state_dict(state_dict) accelerator.register_for_checkpointing(lora_layers) From 14dd3ea01162fc7df9c0036dfc7ba1d88223cb24 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 19 Jan 2023 23:18:43 +0100 Subject: [PATCH 05/11] apply suggestions for PR review. Co-authored-by: Patrick von Platen --- .../text_to_image/train_text_to_image_lora.py | 104 +++++++++--------- 1 file changed, 55 insertions(+), 49 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 4665120ef89f..af44c771de40 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -20,7 +20,7 @@ import os import random from pathlib import Path -from typing import Dict, Optional +from typing import Optional import numpy as np import torch @@ -51,34 +51,31 @@ logger = get_logger(__name__, log_level="INFO") - - def __init__(self, state_dict: Dict[str, torch.Tensor]): - super().__init__() - self.layers = torch.nn.ModuleList(state_dict.values()) - self.mapping = {k: v for k, v in enumerate(state_dict.keys())} - self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} - - # we add a hook to state_dict() and load_state_dict() so that the - # naming fits with `unet.attn_processors` - def map_to(module, state_dict, *args, **kwargs): - new_state_dict = {} - for key, value in state_dict.items(): - num = int(key.split(".")[1]) # 0 is always "layers" - new_key = key.replace(f"layers.{num}", module.mapping[num]) - new_state_dict[new_key] = value - - return new_state_dict - - def map_from(module, state_dict, *args, **kwargs): - all_keys = list(state_dict.keys()) - for key in all_keys: - replace_key = ".".join(key.split(".")[:-3]) - new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") - state_dict[new_key] = state_dict[key] - del state_dict[key] - - self._register_state_dict_hook(map_to) - self._register_load_state_dict_pre_hook(map_from, with_module=True) +def save_model_card(repo_name, images=None, base_model=str, dataset_name=str, repo_folder=None): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +tags: +- stable-diffusion +- stable-diffusion-diffusers +- text-to-image +- diffusers +inference: true +--- + """ + model_card = f""" +# LoRA text2image fine-tuning - {repo_name} +These are LoRA adaption weights for {repo_name}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n +{img_str} +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) def parse_args(): @@ -521,7 +518,7 @@ def main(): column_names = dataset["train"].column_names # 6. Get the column names for input/target. - dataset_columns = dataset_name_mapping.get(args.dataset_name, None) + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) if args.image_column is None: image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] else: @@ -752,16 +749,17 @@ def collate_fn(examples): for _ in range(args.num_validation_images): images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) - for tracker in accelerator.trackers: - if tracker.name == "wandb": - tracker.log( - { - "validation": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) + if accelerator.is_main_process: + for tracker in accelerator.trackers: + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) del pipeline torch.cuda.empty_cache() @@ -773,6 +771,13 @@ def collate_fn(examples): unet.save_attn_procs(args.output_dir) if args.push_to_hub: + save_model_card( + repo_name, + images=images, + base_model=args.pretrained_model_name_or_path, + dataset_name=args.dataset_name, + repo_folder=args.output_dir, + ) repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) # Final inference @@ -791,15 +796,16 @@ def collate_fn(examples): for _ in range(args.num_validation_images): images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) - for tracker in accelerator.trackers: - if tracker.name == "wandb": - tracker.log( - { - "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) - ] - } - ) + if accelerator.is_main_process: + for tracker in accelerator.trackers: + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) accelerator.end_training() From 73c16f85a4c4f0e5748e57e3929eec2c3e5436b6 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 19 Jan 2023 23:20:29 +0100 Subject: [PATCH 06/11] apply make style and make quality. --- examples/text_to_image/train_text_to_image_lora.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index af44c771de40..3be9eaca1ed2 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -20,7 +20,7 @@ import os import random from pathlib import Path -from typing import Optional +from typing import Optional import numpy as np import torch @@ -51,6 +51,7 @@ logger = get_logger(__name__, log_level="INFO") + def save_model_card(repo_name, images=None, base_model=str, dataset_name=str, repo_folder=None): img_str = "" for i, image in enumerate(images): @@ -318,7 +319,6 @@ def parse_args(): if args.dataset_name is None and args.train_data_dir is None: raise ValueError("Need either a dataset name or a training folder.") - return args @@ -455,7 +455,6 @@ def main(): unet.set_attn_processor(lora_attn_procs) lora_layers = AttnProcsLayers(unet.attn_processors) - accelerator.register_for_checkpointing(lora_layers) # Enable TF32 for faster training on Ampere GPUs, @@ -802,7 +801,8 @@ def collate_fn(examples): tracker.log( { "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) ] } ) From 8027f5ff1ef68e994608e015340a1c277ec8c8b8 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 19 Jan 2023 23:23:04 +0100 Subject: [PATCH 07/11] chore: remove mention of dreambooth from text2image. --- examples/text_to_image/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index 827d68bed4bd..8b6983a82237 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -127,7 +127,7 @@ on consumer GPUs like Tesla T4, Tesla V100. ### Training -First, you need to set up your dreambooth training example as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion v1-4](https://hf.co/CompVis/stable-diffusion-v1-4) and the [Pokemons dataset](https://hf.colambdalabs/pokemon-blip-captions). +First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion v1-4](https://hf.co/CompVis/stable-diffusion-v1-4) and the [Pokemons dataset](https://hf.colambdalabs/pokemon-blip-captions). **___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** From dc3b6158e9c23de599e0f8d253e0ebe9f4d84f24 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 20 Jan 2023 12:11:11 +0100 Subject: [PATCH 08/11] add: weight path and wandb run link. --- examples/text_to_image/README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index 8b6983a82237..c9b10ea18a8c 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -162,9 +162,11 @@ accelerate --mixed_precision="fp16" launch train_text_to_image_lora.py \ The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases. -**___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*.___** +**___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*. Also, by using LoRA, it's possible to run `train_text_to_image_lora.py` in consumer GPUs like T4 or V100.** -The final LoRA embedding weights have been uploaded to [TODO](TODO). **___Note: [The final weights](TODO) are only 3 MB in size, which is orders of magnitudes smaller than the original model.** +The final LoRA embedding weights have been uploaded to [sayakpaul/sd-model-finetuned-lora-t4](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4). **___Note: [The final weights](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/pytorch_lora_weights.bin) are only 3 MB in size, which is orders of magnitudes smaller than the original model.** + +You can check some inference samples that were logged during the course of the fine-tuning process [here](https://wandb.ai/sayakpaul/text2image-fine-tune/runs/q4lc0xsw). ### Inference @@ -175,7 +177,7 @@ need to pass the `output_dir` for loading the LoRA weights which, in this case, from diffusers import StableDiffusionPipeline import torch -model_path = "path_to_saved_model" +model_path = "sayakpaul/sd-model-finetuned-lora-t4" pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16) pipe.unet.load_attn_procs(model_path) pipe.to("cuda") From 3c966bc0f1ed9f09272da690e3dce8a68bb81716 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 23 Jan 2023 08:19:45 +0100 Subject: [PATCH 09/11] Apply suggestions from code review --- examples/text_to_image/requirements.txt | 1 - examples/text_to_image/train_text_to_image_lora.py | 1 - 2 files changed, 2 deletions(-) diff --git a/examples/text_to_image/requirements.txt b/examples/text_to_image/requirements.txt index b3f313270abe..9523b3a1a025 100644 --- a/examples/text_to_image/requirements.txt +++ b/examples/text_to_image/requirements.txt @@ -4,6 +4,5 @@ transformers>=4.25.1 datasets ftfy tensorboard -modelcards wandb Jinja2 diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 3be9eaca1ed2..ab5f1290d0fe 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -455,7 +455,6 @@ def main(): unet.set_attn_processor(lora_attn_procs) lora_layers = AttnProcsLayers(unet.attn_processors) - accelerator.register_for_checkpointing(lora_layers) # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices From 55c491c2ccb39eabab7b2060b1f926d58a45d573 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 23 Jan 2023 12:52:15 +0530 Subject: [PATCH 10/11] apply make style. --- examples/text_to_image/train_text_to_image_lora.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index ab5f1290d0fe..82fe32f2cd5b 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -455,7 +455,6 @@ def main(): unet.set_attn_processor(lora_attn_procs) lora_layers = AttnProcsLayers(unet.attn_processors) - # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: From 8e8820c8bc470e5793067f5d4f0424b7961ee35a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 23 Jan 2023 09:22:22 +0200 Subject: [PATCH 11/11] make style --- examples/text_to_image/train_text_to_image_lora.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index ab5f1290d0fe..82fe32f2cd5b 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -455,7 +455,6 @@ def main(): unet.set_attn_processor(lora_attn_procs) lora_layers = AttnProcsLayers(unet.attn_processors) - # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: