From 93df5bb67016a176cab4b58405e4daf5bd1828d9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 16 Oct 2023 14:11:35 +0530 Subject: [PATCH 1/8] [Examples] fix unconditioning generation training example for mixed-precision training (#5407) * fix: unconditional generation example * fix: float in loss. * apply styling. --- .../train_unconditional.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index a3baa3b85b36..12b63439fa68 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -413,6 +413,14 @@ def load_model_hook(models, input_dir): model_config=model.config, ) + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers @@ -559,11 +567,9 @@ def transform_images(examples): progress_bar.update(1) continue - clean_images = batch["input"] + clean_images = batch["input"].to(weight_dtype) # Sample noise that we'll add to the images - noise = torch.randn( - clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16) - ).to(clean_images.device) + noise = torch.randn(clean_images.shape, dtype=weight_dtype, device=clean_images.device) bsz = clean_images.shape[0] # Sample a random timestep for each image timesteps = torch.randint( @@ -579,15 +585,14 @@ def transform_images(examples): model_output = model(noisy_images, timesteps).sample if args.prediction_type == "epsilon": - loss = F.mse_loss(model_output, noise) # this could have different weights! + loss = F.mse_loss(model_output.float(), noise.float()) # this could have different weights! elif args.prediction_type == "sample": alpha_t = _extract_into_tensor( noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1) ) snr_weights = alpha_t / (1 - alpha_t) - loss = snr_weights * F.mse_loss( - model_output, clean_images, reduction="none" - ) # use SNR weighting from distillation paper + # use SNR weighting from distillation paper + loss = snr_weights * F.mse_loss(model_output.float(), clean_images.float(), reduction="none") loss = loss.mean() else: raise ValueError(f"Unsupported prediction type: {args.prediction_type}") From d03c9099bc690870f151035393484c3f5dea2d80 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 16 Oct 2023 15:00:33 +0200 Subject: [PATCH 2/8] [Wuerstchen] text to image training script (#5052) * initial script * formatting * prior trainer wip * add efficient_net_encoder * add CLIPTextModel * add prior ema support * optimizer * fix typo * add dataloader * prompt_embeds and image_embeds * intial training loop * fix output_dir * fix add_noise * accelerator check * make effnet_transforms dynamic * fix training loop * add validation logging * use loaded text_encoder * use PreTrainedTokenizerFast * load weigth from pickle * save_model_card * remove unused file * fix typos * save prior pipeilne in its own folder * fix imports * fix pipe_t2i * scale image_embeds * remove snr_gamma * format * initial lora prior training * log_validation and save * initial gradient working * remove save/load hooks * set set_attn_processor on prior_prior * add lora script * typos * use LoraLoaderMixin for prior pipeline * fix usage * make fix-copies * yse repo_id * write_lora_layers is a staitcmethod * use defualts * fix defaults * undo * Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py Co-authored-by: Patrick von Platen * Update src/diffusers/loaders.py Co-authored-by: Patrick von Platen * Update src/diffusers/loaders.py Co-authored-by: Patrick von Platen * Update src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py * Update src/diffusers/loaders.py Co-authored-by: Patrick von Platen * Update src/diffusers/loaders.py Co-authored-by: Patrick von Platen * add graident checkpoint support to prior * gradient_checkpointing * formatting * Update examples/wuerstchen/text_to_image/README.md Co-authored-by: Pedro Cuenca * Update examples/wuerstchen/text_to_image/README.md Co-authored-by: Pedro Cuenca * Update examples/wuerstchen/text_to_image/README.md Co-authored-by: Pedro Cuenca * Update examples/wuerstchen/text_to_image/README.md Co-authored-by: Pedro Cuenca * Update examples/wuerstchen/text_to_image/README.md Co-authored-by: Pedro Cuenca * Update examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py Co-authored-by: Pedro Cuenca * Update src/diffusers/loaders.py Co-authored-by: Pedro Cuenca * Update examples/wuerstchen/text_to_image/train_text_to_image_prior.py Co-authored-by: Pedro Cuenca * use default unet and text_encoder * fix test --------- Co-authored-by: Patrick von Platen Co-authored-by: Pedro Cuenca --- examples/wuerstchen/text_to_image/README.md | 93 ++ examples/wuerstchen/text_to_image/__init__.py | 0 .../modeling_efficient_net_encoder.py | 23 + .../wuerstchen/text_to_image/requirements.txt | 7 + .../train_text_to_image_lora_prior.py | 888 +++++++++++++++++ .../train_text_to_image_prior.py | 925 ++++++++++++++++++ src/diffusers/loaders.py | 21 +- .../wuerstchen/modeling_wuerstchen_prior.py | 141 ++- .../wuerstchen/pipeline_wuerstchen_prior.py | 5 +- .../schedulers/scheduling_ddpm_wuerstchen.py | 25 +- 10 files changed, 2094 insertions(+), 34 deletions(-) create mode 100644 examples/wuerstchen/text_to_image/README.md create mode 100644 examples/wuerstchen/text_to_image/__init__.py create mode 100644 examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py create mode 100644 examples/wuerstchen/text_to_image/requirements.txt create mode 100644 examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py create mode 100644 examples/wuerstchen/text_to_image/train_text_to_image_prior.py diff --git a/examples/wuerstchen/text_to_image/README.md b/examples/wuerstchen/text_to_image/README.md new file mode 100644 index 000000000000..5378e3ef5253 --- /dev/null +++ b/examples/wuerstchen/text_to_image/README.md @@ -0,0 +1,93 @@ +# Würstchen text-to-image fine-tuning + +## Running locally with PyTorch + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date. To do this, execute the following steps in a new virtual environment: +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install . +``` + +Then cd into the example folder and run +```bash +cd examples/wuerstchen/text_to_image +pip install -r requirements.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` +For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag to the training script. To log in, run: +```bash +huggingface-cli login +``` + +## Prior training + +You can fine-tune the Würstchen prior model with the `train_text_to_image_prior.py` script. Note that we currently support `--gradient_checkpointing` for prior model fine-tuning so you can use it for more GPU memory constrained setups. + +
+ + +```bash +export DATASET_NAME="lambdalabs/pokemon-blip-captions" + +accelerate launch train_text_to_image_prior.py \ + --mixed_precision="fp16" \ + --dataset_name=$DATASET_NAME \ + --resolution=768 \ + --train_batch_size=4 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --dataloader_num_workers=4 \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --checkpoints_total_limit=3 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --validation_prompts="A robot pokemon, 4k photo" \ + --report_to="wandb" \ + --push_to_hub \ + --output_dir="wuerstchen-prior-pokemon-model" +``` + + +## Training with LoRA + +Low-Rank Adaption of Large Language Models (or LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. + +In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages: + +- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114). +- Rank-decomposition matrices have significantly fewer parameters than original model, which means that trained LoRA weights are easily portable. +- LoRA attention layers allow to control to which extent the model is adapted toward new training images via a `scale` parameter. + + +### Prior Training + +First, you need to set up your development environment as explained in the [installation](#Running-locally-with-PyTorch) section. Make sure to set the `DATASET_NAME` environment variable. Here, we will use the [Pokemon captions dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions). + +```bash +export DATASET_NAME="lambdalabs/pokemon-blip-captions" + +accelerate launch train_text_to_image_prior_lora.py \ + --mixed_precision="fp16" \ + --dataset_name=$DATASET_NAME --caption_column="text" \ + --resolution=768 \ + --train_batch_size=8 \ + --num_train_epochs=100 --checkpointing_steps=5000 \ + --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \ + --seed=42 \ + --rank=4 \ + --validation_prompt="cute dragon creature" \ + --report_to="wandb" \ + --push_to_hub \ + --output_dir="wuerstchen-prior-pokemon-lora" +``` diff --git a/examples/wuerstchen/text_to_image/__init__.py b/examples/wuerstchen/text_to_image/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py b/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py new file mode 100644 index 000000000000..bd551ebf1623 --- /dev/null +++ b/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py @@ -0,0 +1,23 @@ +import torch.nn as nn +from torchvision.models import efficientnet_v2_l, efficientnet_v2_s + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin + + +class EfficientNetEncoder(ModelMixin, ConfigMixin): + @register_to_config + def __init__(self, c_latent=16, c_cond=1280, effnet="efficientnet_v2_s"): + super().__init__() + + if effnet == "efficientnet_v2_s": + self.backbone = efficientnet_v2_s(weights="DEFAULT").features + else: + self.backbone = efficientnet_v2_l(weights="DEFAULT").features + self.mapper = nn.Sequential( + nn.Conv2d(c_cond, c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 + ) + + def forward(self, x): + return self.mapper(self.backbone(x)) diff --git a/examples/wuerstchen/text_to_image/requirements.txt b/examples/wuerstchen/text_to_image/requirements.txt new file mode 100644 index 000000000000..a58ad09eca55 --- /dev/null +++ b/examples/wuerstchen/text_to_image/requirements.txt @@ -0,0 +1,7 @@ +accelerate>=0.16.0 +torchvision +transformers>=4.25.1 +wandb +huggingface-cli +bitsandbytes +deepspeed diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py new file mode 100644 index 000000000000..5235fa99cfdd --- /dev/null +++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py @@ -0,0 +1,888 @@ +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState, is_initialized +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, hf_hub_download, upload_folder +from modeling_efficient_net_encoder import EfficientNetEncoder +from torchvision import transforms +from tqdm import tqdm +from transformers import CLIPTextModel, PreTrainedTokenizerFast +from transformers.utils import ContextManagers + +from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler, WuerstchenPriorPipeline +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.optimization import get_scheduler +from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior +from diffusers.utils import check_min_version, is_wandb_available, make_image_grid +from diffusers.utils.logging import set_verbosity_error, set_verbosity_info + + +if is_wandb_available(): + import wandb + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.22.0") + +logger = get_logger(__name__, log_level="INFO") + +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + + +def save_model_card( + args, + repo_id: str, + images=None, + repo_folder=None, +): + img_str = "" + if len(images) > 0: + image_grid = make_image_grid(images, 1, len(args.validation_prompts)) + image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png")) + img_str += "![val_imgs_grid](./val_imgs_grid.png)\n" + + yaml = f""" +--- +license: mit +base_model: {args.pretrained_prior_model_name_or_path} +datasets: +- {args.dataset_name} +tags: +- wuerstchen +- text-to-image +- diffusers +- lora +inference: true +--- + """ + model_card = f""" +# LoRA Finetuning - {repo_id} + +This pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n +{img_str} + +## Pipeline usage + +You can use the pipeline like so: + +```python +from diffusers import DiffusionPipeline +import torch + +pipeline = AutoPipelineForText2Image.from_pretrained( + "{args.pretrained_decoder_model_name_or_path}", torch_dtype={args.weight_dtype} + ) +# load lora weights from folder: +pipeline.prior_pipe.load_lora_weights("{repo_id}", torch_dtype={args.weight_dtype}) + +image = pipeline(prompt=prompt).images[0] +image.save("my_image.png") +``` + +## Training info + +These are the key hyperparameters used during training: + +* LoRA rank: {args.rank} +* Epochs: {args.num_train_epochs} +* Learning rate: {args.learning_rate} +* Batch size: {args.train_batch_size} +* Gradient accumulation steps: {args.gradient_accumulation_steps} +* Image resolution: {args.resolution} +* Mixed-precision: {args.mixed_precision} + +""" + wandb_info = "" + if is_wandb_available(): + wandb_run_url = None + if wandb.run is not None: + wandb_run_url = wandb.run.url + + if wandb_run_url is not None: + wandb_info = f""" +More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}). +""" + + model_card += wandb_info + + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator, weight_dtype, epoch): + logger.info("Running validation... ") + + pipeline = AutoPipelineForText2Image.from_pretrained( + args.pretrained_decoder_model_name_or_path, + prior_text_encoder=accelerator.unwrap_model(text_encoder), + prior_tokenizer=tokenizer, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.prior_prior.set_attn_processor(attn_processors) + pipeline.set_progress_bar_config(disable=True) + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + images = [] + for i in range(len(args.validation_prompts)): + with torch.autocast("cuda"): + image = pipeline( + args.validation_prompts[i], + prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, + generator=generator, + height=args.resolution, + width=args.resolution, + ).images[0] + + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + elif tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}") + for i, image in enumerate(images) + ] + } + ) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + torch.cuda.empty_cache() + + return images + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of finetuning Würstchen Prior.") + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--pretrained_decoder_model_name_or_path", + type=str, + default="warp-ai/wuerstchen", + required=False, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_prior_model_name_or_path", + type=str, + default="warp-ai/wuerstchen-prior", + required=False, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="wuerstchen-model-finetuned-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="learning rate", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument( + "--adam_weight_decay", + type=float, + default=0.0, + required=False, + help="weight decay_to_use", + ) + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + return args + + +def main(): + args = parse_args() + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration( + total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir + ) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load scheduler, effnet, tokenizer, clip_model + noise_scheduler = DDPMWuerstchenScheduler() + tokenizer = PreTrainedTokenizerFast.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="tokenizer" + ) + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt") + state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu") + image_encoder = EfficientNetEncoder() + image_encoder.load_state_dict(state_dict["effnet_state_dict"]) + image_encoder.eval() + + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype + ).eval() + + # Freeze text_encoder, cast to weight_dtype and image_encoder and move to device + text_encoder.requires_grad_(False) + image_encoder.requires_grad_(False) + image_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # load prior model, cast to weight_dtype and move to device + prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior") + prior.to(accelerator.device, dtype=weight_dtype) + + # lora attn processor + lora_attn_procs = {} + for name in prior.attn_processors.keys(): + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=prior.config["c"], rank=args.rank) + prior.set_attn_processor(lora_attn_procs) + lora_layers = AttnProcsLayers(prior.attn_processors) + + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + optimizer = optimizer_cls( + lora_layers.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + text_input_ids = inputs.input_ids + text_mask = inputs.attention_mask.bool() + return text_input_ids, text_mask + + effnet_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images] + examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples) + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + effnet_pixel_values = torch.stack([example["effnet_pixel_values"] for example in examples]) + effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format).float() + text_input_ids = torch.stack([example["text_input_ids"] for example in examples]) + text_mask = torch.stack([example["text_mask"] for example in examples]) + return {"effnet_pixel_values": effnet_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask} + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + lora_layers, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + tracker_config.pop("validation_prompts") + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + prior.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(prior): + # Convert images to latent space + text_input_ids, text_mask, effnet_images = ( + batch["text_input_ids"], + batch["text_mask"], + batch["effnet_pixel_values"].to(weight_dtype), + ) + + with torch.no_grad(): + text_encoder_output = text_encoder(text_input_ids, attention_mask=text_mask) + prompt_embeds = text_encoder_output.last_hidden_state + image_embeds = image_encoder(effnet_images) + # scale + image_embeds = image_embeds.add(1.0).div(42.0) + + # Sample noise that we'll add to the image_embeds + noise = torch.randn_like(image_embeds) + bsz = image_embeds.shape[0] + + # Sample a random timestep for each image + timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype) + + # add noise to latent + noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps) + + # Predict the noise residual and compute losscd + pred_noise = prior(noisy_latents, timesteps, prompt_embeds) + + # vanilla loss + loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean") + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(lora_layers.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + log_validation( + text_encoder, tokenizer, prior.attn_processors, args, accelerator, weight_dtype, global_step + ) + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + prior = prior.to(torch.float32) + WuerstchenPriorPipeline.save_lora_weights( + os.path.join(args.output_dir, "prior_lora"), + unet_lora_layers=lora_layers, + ) + + # Run a final round of inference. + images = [] + if args.validation_prompts is not None: + logger.info("Running inference for collecting generated images...") + pipeline = AutoPipelineForText2Image.from_pretrained( + args.pretrained_decoder_model_name_or_path, + prior_text_encoder=accelerator.unwrap_model(text_encoder), + prior_tokenizer=tokenizer, + ) + pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype) + # load lora weights + pipeline.prior_pipe.load_lora_weights(os.path.join(args.output_dir, "prior_lora")) + + pipeline.set_progress_bar_config(disable=True) + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + for i in range(len(args.validation_prompts)): + with torch.autocast("cuda"): + image = pipeline( + args.validation_prompts[i], + prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, + generator=generator, + width=args.resolution, + height=args.resolution, + ).images[0] + images.append(image) + + if args.push_to_hub: + save_model_card(args, repo_id, images, repo_folder=args.output_dir) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py new file mode 100644 index 000000000000..92f63c93fc1a --- /dev/null +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -0,0 +1,925 @@ +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import accelerate +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState, is_initialized +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, hf_hub_download, upload_folder +from modeling_efficient_net_encoder import EfficientNetEncoder +from packaging import version +from torchvision import transforms +from tqdm import tqdm +from transformers import CLIPTextModel, PreTrainedTokenizerFast +from transformers.utils import ContextManagers + +from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler +from diffusers.optimization import get_scheduler +from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, is_wandb_available, make_image_grid +from diffusers.utils.logging import set_verbosity_error, set_verbosity_info + + +if is_wandb_available(): + import wandb + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.22.0") + +logger = get_logger(__name__, log_level="INFO") + +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + + +def save_model_card( + args, + repo_id: str, + images=None, + repo_folder=None, +): + img_str = "" + if len(images) > 0: + image_grid = make_image_grid(images, 1, len(args.validation_prompts)) + image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png")) + img_str += "![val_imgs_grid](./val_imgs_grid.png)\n" + + yaml = f""" +--- +license: mit +base_model: {args.pretrained_prior_model_name_or_path} +datasets: +- {args.dataset_name} +tags: +- wuerstchen +- text-to-image +- diffusers +inference: true +--- + """ + model_card = f""" +# Finetuning - {repo_id} + +This pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n +{img_str} + +## Pipeline usage + +You can use the pipeline like so: + +```python +from diffusers import DiffusionPipeline +import torch + +pipe_prior = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype={args.weight_dtype}) +pipe_t2i = DiffusionPipeline.from_pretrained("{args.pretrained_decoder_model_name_or_path}", torch_dtype={args.weight_dtype}) +prompt = "{args.validation_prompts[0]}" +(image_embeds,) = pipe_prior(prompt).to_tuple() +image = pipe_t2i(image_embeddings=image_embeds, prompt=prompt).images[0] +image.save("my_image.png") +``` + +## Training info + +These are the key hyperparameters used during training: + +* Epochs: {args.num_train_epochs} +* Learning rate: {args.learning_rate} +* Batch size: {args.train_batch_size} +* Gradient accumulation steps: {args.gradient_accumulation_steps} +* Image resolution: {args.resolution} +* Mixed-precision: {args.mixed_precision} + +""" + wandb_info = "" + if is_wandb_available(): + wandb_run_url = None + if wandb.run is not None: + wandb_run_url = wandb.run.url + + if wandb_run_url is not None: + wandb_info = f""" +More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}). +""" + + model_card += wandb_info + + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch): + logger.info("Running validation... ") + + pipeline = AutoPipelineForText2Image.from_pretrained( + args.pretrained_decoder_model_name_or_path, + prior_prior=accelerator.unwrap_model(prior), + prior_text_encoder=accelerator.unwrap_model(text_encoder), + prior_tokenizer=tokenizer, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + images = [] + for i in range(len(args.validation_prompts)): + with torch.autocast("cuda"): + image = pipeline( + args.validation_prompts[i], + prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, + generator=generator, + height=args.resolution, + width=args.resolution, + ).images[0] + + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + elif tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}") + for i, image in enumerate(images) + ] + } + ) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + torch.cuda.empty_cache() + + return images + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of finetuning Würstchen Prior.") + parser.add_argument( + "--pretrained_decoder_model_name_or_path", + type=str, + default="warp-ai/wuerstchen", + required=False, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_prior_model_name_or_path", + type=str, + default="warp-ai/wuerstchen-prior", + required=False, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="wuerstchen-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="learning rate", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument( + "--adam_weight_decay", + type=float, + default=0.0, + required=False, + help="weight decay_to_use", + ) + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + return args + + +def main(): + args = parse_args() + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration( + total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir + ) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load scheduler, effnet, tokenizer, clip_model + noise_scheduler = DDPMWuerstchenScheduler() + tokenizer = PreTrainedTokenizerFast.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="tokenizer" + ) + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt") + state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu") + image_encoder = EfficientNetEncoder() + image_encoder.load_state_dict(state_dict["effnet_state_dict"]) + image_encoder.eval() + + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype + ).eval() + + # Freeze text_encoder and image_encoder + text_encoder.requires_grad_(False) + image_encoder.requires_grad_(False) + + # load prior model + prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior") + + # Create EMA for the prior + if args.use_ema: + ema_prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior") + ema_prior = EMAModel(ema_prior.parameters(), model_cls=WuerstchenPrior, model_config=ema_prior.config) + ema_prior.to(accelerator.device) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if args.use_ema: + ema_prior.save_pretrained(os.path.join(output_dir, "prior_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "prior")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "prior_ema"), WuerstchenPrior) + ema_prior.load_state_dict(load_model.state_dict()) + ema_prior.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = WuerstchenPrior.from_pretrained(input_dir, subfolder="prior") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + prior.enable_gradient_checkpointing() + + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + optimizer = optimizer_cls( + prior.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + text_input_ids = inputs.input_ids + text_mask = inputs.attention_mask.bool() + return text_input_ids, text_mask + + effnet_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images] + examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples) + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + effnet_pixel_values = torch.stack([example["effnet_pixel_values"] for example in examples]) + effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format).float() + text_input_ids = torch.stack([example["text_input_ids"] for example in examples]) + text_mask = torch.stack([example["text_mask"] for example in examples]) + return {"effnet_pixel_values": effnet_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask} + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + prior, optimizer, train_dataloader, lr_scheduler + ) + image_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + tracker_config.pop("validation_prompts") + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + prior.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(prior): + # Convert images to latent space + text_input_ids, text_mask, effnet_images = ( + batch["text_input_ids"], + batch["text_mask"], + batch["effnet_pixel_values"].to(weight_dtype), + ) + + with torch.no_grad(): + text_encoder_output = text_encoder(text_input_ids, attention_mask=text_mask) + prompt_embeds = text_encoder_output.last_hidden_state + image_embeds = image_encoder(effnet_images) + # scale + image_embeds = image_embeds.add(1.0).div(42.0) + + # Sample noise that we'll add to the image_embeds + noise = torch.randn_like(image_embeds) + bsz = image_embeds.shape[0] + + # Sample a random timestep for each image + timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype) + + # add noise to latent + noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps) + + # Predict the noise residual and compute losscd + pred_noise = prior(noisy_latents, timesteps, prompt_embeds) + + # vanilla loss + loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean") + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema_prior.step(prior.parameters()) + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_prior.store(prior.parameters()) + ema_prior.copy_to(prior.parameters()) + log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step) + if args.use_ema: + # Switch back to the original UNet parameters. + ema_prior.restore(prior.parameters()) + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + prior = accelerator.unwrap_model(prior) + if args.use_ema: + ema_prior.copy_to(prior.parameters()) + + pipeline = AutoPipelineForText2Image.from_pretrained( + args.pretrained_decoder_model_name_or_path, + prior_prior=prior, + prior_text_encoder=accelerator.unwrap_model(text_encoder), + prior_tokenizer=tokenizer, + ) + pipeline.prior_pipe.save_pretrained(os.path.join(args.output_dir, "prior_pipeline")) + + # Run a final round of inference. + images = [] + if args.validation_prompts is not None: + logger.info("Running inference for collecting generated images...") + pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype) + pipeline.set_progress_bar_config(disable=True) + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + for i in range(len(args.validation_prompts)): + with torch.autocast("cuda"): + image = pipeline( + args.validation_prompts[i], + prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, + generator=generator, + width=args.resolution, + height=args.resolution, + ).images[0] + images.append(image) + + if args.push_to_hub: + save_model_card(args, repo_id, images, repo_folder=args.output_dir) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 695a22d955da..483030b06c7f 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1208,7 +1208,7 @@ def load_lora_weights( self.load_lora_into_unet( state_dict, network_alphas=network_alphas, - unet=self.unet, + unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, low_cpu_mem_usage=low_cpu_mem_usage, adapter_name=adapter_name, _pipeline=self, @@ -1216,7 +1216,9 @@ def load_lora_weights( self.load_lora_into_text_encoder( state_dict, network_alphas=network_alphas, - text_encoder=self.text_encoder, + text_encoder=getattr(self, self.text_encoder_name) + if not hasattr(self, "text_encoder") + else self.text_encoder, lora_scale=self.lora_scale, low_cpu_mem_usage=low_cpu_mem_usage, adapter_name=adapter_name, @@ -1577,7 +1579,7 @@ def load_lora_into_unet( """ low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), - # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # their prefixes. keys = list(state_dict.keys()) @@ -1961,7 +1963,7 @@ def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameter @classmethod def save_lora_weights( - self, + cls, save_directory: Union[str, os.PathLike], unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, @@ -2001,7 +2003,7 @@ def save_lora_weights( unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers ) - unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()} + unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()} state_dict.update(unet_lora_state_dict) if text_encoder_lora_layers is not None: @@ -2012,12 +2014,12 @@ def save_lora_weights( ) text_encoder_lora_state_dict = { - f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items() + f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items() } state_dict.update(text_encoder_lora_state_dict) # Save the model - self.write_lora_layers( + cls.write_lora_layers( state_dict=state_dict, save_directory=save_directory, is_main_process=is_main_process, @@ -2026,6 +2028,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) + @staticmethod def write_lora_layers( state_dict: Dict[str, torch.Tensor], save_directory: str, @@ -3248,7 +3251,7 @@ def load_lora_weights( @classmethod def save_lora_weights( - self, + cls, save_directory: Union[str, os.PathLike], unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, @@ -3299,7 +3302,7 @@ def pack_weights(layers, prefix): state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) - self.write_lora_layers( + cls.write_lora_layers( state_dict=state_dict, save_directory=save_directory, is_main_process=is_main_process, diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index 9bd29b59b3af..ca72ce581fcc 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -14,16 +14,29 @@ # limitations under the License. import math +from typing import Dict, Union import torch import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) from ...models.modeling_utils import ModelMixin +from ...utils import is_torch_version from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm -class WuerstchenPrior(ModelMixin, ConfigMixin): +class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + unet_name = "prior" + _supports_gradient_checkpointing = True + @register_to_config def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1): super().__init__() @@ -45,6 +58,90 @@ def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dro nn.Conv2d(c, c_in * 2, kernel_size=1), ) + self.gradient_checkpointing = False + self.set_default_attn_processor() + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor, _remove_lora=_remove_lora) + else: + module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor, _remove_lora=True) + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + def gen_r_embedding(self, r, max_positions=10000): r = r * max_positions half_dim = self.c_r // 2 @@ -61,12 +158,42 @@ def forward(self, x, r, c): x = self.projection(x) c_embed = self.cond_mapper(c) r_embed = self.gen_r_embedding(r) - for block in self.blocks: - if isinstance(block, AttnBlock): - x = block(x, c_embed) - elif isinstance(block, TimestepBlock): - x = block(x, r_embed) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + for block in self.blocks: + if isinstance(block, AttnBlock): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, c_embed, use_reentrant=False + ) + elif isinstance(block, TimestepBlock): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, r_embed, use_reentrant=False + ) + else: + x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False) else: - x = block(x) + for block in self.blocks: + if isinstance(block, AttnBlock): + x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, c_embed) + elif isinstance(block, TimestepBlock): + x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, r_embed) + else: + x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x) + else: + for block in self.blocks: + if isinstance(block, AttnBlock): + x = block(x, c_embed) + elif isinstance(block, TimestepBlock): + x = block(x, r_embed) + else: + x = block(x) a, b = self.out(x).chunk(2, dim=1) return (x_in - a) / ((1 - b).abs() + 1e-5) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index dba6d7bb06db..bc8e9cd998c0 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -20,6 +20,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer +from ...loaders import LoraLoaderMixin from ...schedulers import DDPMWuerstchenScheduler from ...utils import ( BaseOutput, @@ -65,7 +66,7 @@ class WuerstchenPriorPipelineOutput(BaseOutput): image_embeddings: Union[torch.FloatTensor, np.ndarray] -class WuerstchenPriorPipeline(DiffusionPipeline): +class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin): """ Pipeline for generating image prior for Wuerstchen. @@ -90,6 +91,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline): Default resolution for multiple images generated. """ + unet_name = "prior" + text_encoder_name = "text_encoder" model_cpu_offload_seq = "text_encoder->prior" def __init__( diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index 781efb12b18b..bafa6d7f1b87 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -211,24 +211,15 @@ def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor, + timesteps: torch.FloatTensor, ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples + device = original_samples.device + dtype = original_samples.dtype + alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view( + timesteps.size(0), *[1 for _ in original_samples.shape[1:]] + ) + noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise + return noisy_samples.to(dtype=dtype) def __len__(self): return self.config.num_train_timesteps From 5495073faf05790e48c4e27fe2826b467b69cf4a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 16 Oct 2023 18:41:37 +0530 Subject: [PATCH 3/8] [Docs] add docs on peft diffusers integration (#5359) * add docs on peft diffusers integration/ Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: pacman100 <13534540+pacman100@users.noreply.github.com> * update URLs. * Apply suggestions from code review Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * minor changes * Update docs/source/en/tutorials/using_peft_for_inference.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * reflect the latest changes. * note about update. --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: pacman100 <13534540+pacman100@users.noreply.github.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 + .../en/tutorials/using_peft_for_inference.md | 165 ++++++++++++++++++ 2 files changed, 167 insertions(+) create mode 100644 docs/source/en/tutorials/using_peft_for_inference.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index b8aa71dacbe2..88da548bd597 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -17,6 +17,8 @@ title: AutoPipeline - local: tutorials/basic_training title: Train a diffusion model + - local: tutorials/using_peft_for_inference + title: Inference with PEFT title: Tutorials - sections: - sections: diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md new file mode 100644 index 000000000000..4629cf8ba43c --- /dev/null +++ b/docs/source/en/tutorials/using_peft_for_inference.md @@ -0,0 +1,165 @@ + + +[[open-in-colab]] + +# Inference with PEFT + +There are many adapters trained in different styles to achieve different effects. You can even combine multiple adapters to create new and unique images. With the 🤗 [PEFT](https://huggingface.co/docs/peft/index) integration in 🤗 Diffusers, it is really easy to load and manage adapters for inference. In this guide, you'll learn how to use different adapters with [Stable Diffusion XL (SDXL)](./pipelines/stable_diffusion/stable_diffusion_xl) for inference. + +Throughout this guide, you'll use LoRA as the main adapter technique, so we'll use the terms LoRA and adapter interchangeably. You should have some familiarity with LoRA, and if you don't, we welcome you to check out the [LoRA guide](https://huggingface.co/docs/peft/conceptual_guides/lora). + +Let's first install all the required libraries. + +```bash +!pip install -q transformers accelerate +# Will be updated once the stable releases are done. +!pip install -q git+https://github.com/huggingface/peft.git +!pip install -q git+https://github.com/huggingface/diffusers.git +``` + +Now, let's load a pipeline with a SDXL checkpoint: + +```python +from diffusers import DiffusionPipeline +import torch + +pipe_id = "stabilityai/stable-diffusion-xl-base-1.0" +pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda") +``` + + +Next, load a LoRA checkpoint with the [`~diffusers.loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] method. + +With the 🤗 PEFT integration, you can assign a specific `adapter_name` to the checkpoint, which let's you easily switch between different LoRA checkpoints. Let's call this adapter `"toy"`. + +```python +pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy") +``` + +And then perform inference: + +```python +prompt = "toy_face of a hacker with a hoodie" + +lora_scale= 0.9 +image = pipe( + prompt, num_inference_steps=30, cross_attention_kwargs={"scale": lora_scale}, generator=torch.manual_seed(0) +).images[0] +image +``` + +![toy-face](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_8_1.png) + + +With the `adapter_name` parameter, it is really easy to use another adapter for inference! Load the [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) adapter that has been fine-tuned to generate pixel art images, and let's call it `"pixel"`. + +The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter. But you can activate the `"pixel"` adapter with the [`~diffusers.loaders.set_adapters`] method as shown below: + +```python +pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") +pipe.set_adapters("pixel") +``` + +Let's now generate an image with the second adapter and check the result: + +```python +prompt = "a hacker with a hoodie, pixel art" +image = pipe( + prompt, num_inference_steps=30, cross_attention_kwargs={"scale": lora_scale}, generator=torch.manual_seed(0) +).images[0] +image +``` + +![pixel-art](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_12_1.png) + +## Combine multiple adapters + +You can also perform multi-adapter inference where you combine different adapter checkpoints for inference. + +Once again, use the [`~diffusers.loaders.set_adapters`] method to activate two LoRA checkpoints and specify the weight for how the checkpoints should be combined. + +```python +pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0]) +``` + +Now that we have set these two adapters, let's generate an image from the combined adapters! + + + +LoRA checkpoints in the diffusion community are almost always obtained with [DreamBooth](https://huggingface.co/docs/diffusers/main/en/training/dreambooth). DreamBooth training often relies on "trigger" words in the input text prompts in order for the generation results to look as expected. When you combine multiple LoRA checkpoints, it's important to ensure the trigger words for the corresponding LoRA checkpoints are present in the input text prompts. + + + +The trigger words for [CiroN2022/toy-face](https://hf.co/CiroN2022/toy-face) and [nerijs/pixel-art-xl](https://hf.co/nerijs/pixel-art-xl) are found in their repositories. + + +```python +# Notice how the prompt is constructed. +prompt = "toy_face of a hacker with a hoodie, pixel art" +image = pipe( + prompt, num_inference_steps=30, cross_attention_kwargs={"scale": 1.0}, generator=torch.manual_seed(0) +).images[0] +image +``` + +![toy-face-pixel-art](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_16_1.png) + +Impressive! As you can see, the model was able to generate an image that mixes the characteristics of both adapters. + +If you want to go back to using only one adapter, use the [`~diffusers.loaders.set_adapters`] method to activate the `"toy"` adapter: + +```python +# First, set the adapter. +pipe.set_adapters("toy") + +# Then, run inference. +prompt = "toy_face of a hacker with a hoodie" +lora_scale= 0.9 +image = pipe( + prompt, num_inference_steps=30, cross_attention_kwargs={"scale": lora_scale}, generator=torch.manual_seed(0) +).images[0] +image +``` + +![toy-face-again](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_18_1.png) + + +If you want to switch to only the base model, disable all LoRAs with the [`~diffusers.loaders.disable_lora`] method. + + +```python +pipe.disable_lora() + +prompt = "toy_face of a hacker with a hoodie" +lora_scale= 0.9 +image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0] +image +``` + +![no-lora](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_20_1.png) + +## Monitoring active adapters + +You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, you can easily check the list of active adapters using the [`~diffusers.loaders.get_active_adapters`] method: + +```python +active_adapters = pipe.get_active_adapters() +>>> ["toy", "pixel"] +``` + +You can also get the active adapters of each pipeline component with [`~diffusers.loaders.get_list_adapters`]: + +```python +list_adapters_component_wise = pipe.get_list_adapters() +>>> {"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]} +``` From 0ea78f9707f3ccb0fe3eaa74f247e4d3d8f47b6b Mon Sep 17 00:00:00 2001 From: Heinz-Alexander Fuetterer <35225576+afuetterer@users.noreply.github.com> Date: Mon, 16 Oct 2023 15:23:37 +0200 Subject: [PATCH 4/8] chore: fix typos (#5386) * chore: fix typos * Update src/diffusers/pipelines/shap_e/renderer.py Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> --------- Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> --- CONTRIBUTING.md | 6 +++--- docs/source/en/api/pipelines/diffedit.md | 2 +- docs/source/en/api/pipelines/kandinsky.md | 2 +- docs/source/en/api/pipelines/kandinsky_v22.md | 2 +- docs/source/en/conceptual/contribution.md | 6 +++--- docs/source/en/optimization/torch2.0.md | 2 +- docs/source/en/training/custom_diffusion.md | 4 ++-- docs/source/en/training/text_inversion.md | 2 +- docs/source/en/using-diffusers/controlnet.md | 2 +- docs/source/en/using-diffusers/shap-e.md | 2 +- examples/community/run_onnx_controlnet.py | 2 +- examples/community/run_tensorrt_controlnet.py | 2 +- examples/custom_diffusion/README.md | 4 ++-- examples/dreambooth/train_dreambooth.py | 2 +- examples/dreambooth/train_dreambooth_lora.py | 2 +- examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 +- examples/research_projects/colossalai/README.md | 2 +- .../research_projects/multi_subject_dreambooth/README.md | 2 +- examples/research_projects/sdxl_flax/README.md | 2 +- src/diffusers/pipelines/shap_e/renderer.py | 4 ++-- .../pipelines/unidiffuser/modeling_text_decoder.py | 2 +- utils/custom_init_isort.py | 2 +- utils/release.py | 2 +- 23 files changed, 30 insertions(+), 30 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ae2be777aa37..124d2adf1ce5 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -40,7 +40,7 @@ In the following, we give an overview of different ways to contribute, ranked by As said before, **all contributions are valuable to the community**. In the following, we will explain each contribution a bit more in detail. -For all contributions 4.-9. you will need to open a PR. It is explained in detail how to do so in [Opening a pull requst](#how-to-open-a-pr) +For all contributions 4.-9. you will need to open a PR. It is explained in detail how to do so in [Opening a pull request](#how-to-open-a-pr) ### 1. Asking and answering questions on the Diffusers discussion forum or on the Diffusers Discord @@ -63,7 +63,7 @@ In the same spirit, you are of immense help to the community by answering such q **Please** keep in mind that the more effort you put into asking or answering a question, the higher the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database. -In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accesible*, and *well-formated/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section. +In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formated/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section. **NOTE about channels**: [*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that we posted some time ago. @@ -168,7 +168,7 @@ more precise, provide the link to a duplicated issue or redirect them to [the fo If you have verified that the issued bug report is correct and requires a correction in the source code, please have a look at the next sections. -For all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull requst](#how-to-open-a-pr) section. +For all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull request](#how-to-open-a-pr) section. ### 4. Fixing a "Good first issue" diff --git a/docs/source/en/api/pipelines/diffedit.md b/docs/source/en/api/pipelines/diffedit.md index ef698ff33d1b..2ba7f9092907 100644 --- a/docs/source/en/api/pipelines/diffedit.md +++ b/docs/source/en/api/pipelines/diffedit.md @@ -34,7 +34,7 @@ this in the generated mask, you simply have to set the embeddings related to the `source_prompt` and "dog" to `target_prompt`. * When generating partially inverted latents using `invert`, assign a caption or text embedding describing the overall image to the `prompt` argument to help guide the inverse latent sampling process. In most cases, the -source concept is sufficently descriptive to yield good results, but feel free to explore alternatives. +source concept is sufficiently descriptive to yield good results, but feel free to explore alternatives. * When calling the pipeline to generate the final edited image, assign the source concept to `negative_prompt` and the target concept to `prompt`. Taking the above example, you simply have to set the embeddings related to the phrases including "cat" to `negative_prompt` and "dog" to `prompt`. diff --git a/docs/source/en/api/pipelines/kandinsky.md b/docs/source/en/api/pipelines/kandinsky.md index 069c7996053a..086821a3bc0a 100644 --- a/docs/source/en/api/pipelines/kandinsky.md +++ b/docs/source/en/api/pipelines/kandinsky.md @@ -396,7 +396,7 @@ t2i_pipe.unet.set_attn_processor(AttnAddedKVProcessor()) ``` With PyTorch >= 2.0, you can also use Kandinsky with `torch.compile` which depending -on your hardware can signficantly speed-up your inference time once the model is compiled. +on your hardware can significantly speed-up your inference time once the model is compiled. To use Kandinsksy with `torch.compile`, you can do: ```py diff --git a/docs/source/en/api/pipelines/kandinsky_v22.md b/docs/source/en/api/pipelines/kandinsky_v22.md index 3f88997ff4f5..4967ccc0a5cd 100644 --- a/docs/source/en/api/pipelines/kandinsky_v22.md +++ b/docs/source/en/api/pipelines/kandinsky_v22.md @@ -263,7 +263,7 @@ t2i_pipe.unet.set_attn_processor(AttnAddedKVProcessor()) ``` With PyTorch >= 2.0, you can also use Kandinsky with `torch.compile` which depending -on your hardware can signficantly speed-up your inference time once the model is compiled. +on your hardware can significantly speed-up your inference time once the model is compiled. To use Kandinsksy with `torch.compile`, you can do: ```py diff --git a/docs/source/en/conceptual/contribution.md b/docs/source/en/conceptual/contribution.md index ea1d15f2124c..74393ebf3eb3 100644 --- a/docs/source/en/conceptual/contribution.md +++ b/docs/source/en/conceptual/contribution.md @@ -40,7 +40,7 @@ In the following, we give an overview of different ways to contribute, ranked by As said before, **all contributions are valuable to the community**. In the following, we will explain each contribution a bit more in detail. -For all contributions 4.-9. you will need to open a PR. It is explained in detail how to do so in [Opening a pull requst](#how-to-open-a-pr) +For all contributions 4.-9. you will need to open a PR. It is explained in detail how to do so in [Opening a pull request](#how-to-open-a-pr) ### 1. Asking and answering questions on the Diffusers discussion forum or on the Diffusers Discord @@ -63,7 +63,7 @@ In the same spirit, you are of immense help to the community by answering such q **Please** keep in mind that the more effort you put into asking or answering a question, the higher the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database. -In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accesible*, and *well-formated/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section. +In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formated/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section. **NOTE about channels**: [*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that we posted some time ago. @@ -168,7 +168,7 @@ more precise, provide the link to a duplicated issue or redirect them to [the fo If you have verified that the issued bug report is correct and requires a correction in the source code, please have a look at the next sections. -For all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull requst](#how-to-open-a-pr) section. +For all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull request](#how-to-open-a-pr) section. ### 4. Fixing a `Good first issue` diff --git a/docs/source/en/optimization/torch2.0.md b/docs/source/en/optimization/torch2.0.md index c0d3a037b9b1..1e07b876514f 100644 --- a/docs/source/en/optimization/torch2.0.md +++ b/docs/source/en/optimization/torch2.0.md @@ -70,7 +70,7 @@ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images[0] ``` -Depending on GPU type, `torch.compile` can provide an *addtional speed-up* of **5-300x** on top of SDPA! If you're using more recent GPU architectures such as Ampere (A100, 3090), Ada (4090), and Hopper (H100), `torch.compile` is able to squeeze even more performance out of these GPUs. +Depending on GPU type, `torch.compile` can provide an *additional speed-up* of **5-300x** on top of SDPA! If you're using more recent GPU architectures such as Ampere (A100, 3090), Ada (4090), and Hopper (H100), `torch.compile` is able to squeeze even more performance out of these GPUs. Compilation requires some time to complete, so it is best suited for situations where you prepare your pipeline once and then perform the same type of inference operations multiple times. For example, calling the compiled pipeline on a different image size triggers compilation again which can be expensive. diff --git a/docs/source/en/training/custom_diffusion.md b/docs/source/en/training/custom_diffusion.md index 2c9156d65f31..153ae81f1216 100644 --- a/docs/source/en/training/custom_diffusion.md +++ b/docs/source/en/training/custom_diffusion.md @@ -69,7 +69,7 @@ write_basic_config() Now let's get our dataset. Download dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it. To use your own dataset, take a look at the [Create a dataset for training](create_dataset) guide. -We also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the the given target image. The following flags enable the regularization `with_prior_preservation`, `real_prior` with `prior_loss_weight=1.`. +We also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the given target image. The following flags enable the regularization `with_prior_preservation`, `real_prior` with `prior_loss_weight=1.`. The `class_prompt` should be the category name same as target image. The collected real images are with text captions similar to the `class_prompt`. The retrieved image are saved in `class_data_dir`. You can disable `real_prior` to use generated images as regularization. To collect the real images use this command first before training. ```bash @@ -106,7 +106,7 @@ accelerate launch train_custom_diffusion.py \ **Use `--enable_xformers_memory_efficient_attention` for faster training with lower VRAM requirement (16GB per GPU). Follow [this guide](https://github.com/facebookresearch/xformers) for installation instructions.** -To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (whcih we HIGHLY recommend), follow these steps: +To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (which we HIGHLY recommend), follow these steps: * Install `wandb`: `pip install wandb`. * Authorize: `wandb login`. diff --git a/docs/source/en/training/text_inversion.md b/docs/source/en/training/text_inversion.md index 48904c32371b..7cc7d57e7c6c 100644 --- a/docs/source/en/training/text_inversion.md +++ b/docs/source/en/training/text_inversion.md @@ -192,7 +192,7 @@ been added to the text encoder embedding matrix and consequently been trained. 💡 The community has created a large library of different textual inversion embedding vectors, called [sd-concepts-library](https://huggingface.co/sd-concepts-library). -Instead of training textual inversion embeddings from scratch you can also see whether a fitting textual inversion embedding has already been added to the libary. +Instead of training textual inversion embeddings from scratch you can also see whether a fitting textual inversion embedding has already been added to the library. diff --git a/docs/source/en/using-diffusers/controlnet.md b/docs/source/en/using-diffusers/controlnet.md index be02e999e1b8..5ecf0748d275 100644 --- a/docs/source/en/using-diffusers/controlnet.md +++ b/docs/source/en/using-diffusers/controlnet.md @@ -434,7 +434,7 @@ high_threshold = 200 canny_image = cv2.Canny(canny_image, low_threshold, high_threshold) -# zero out middle columns of image where pose will be overlayed +# zero out middle columns of image where pose will be overlaid zero_start = canny_image.shape[1] // 4 zero_end = zero_start + canny_image.shape[1] // 2 canny_image[:, zero_start:zero_end] = 0 diff --git a/docs/source/en/using-diffusers/shap-e.md b/docs/source/en/using-diffusers/shap-e.md index b74a652582ec..68542bf56773 100644 --- a/docs/source/en/using-diffusers/shap-e.md +++ b/docs/source/en/using-diffusers/shap-e.md @@ -62,7 +62,7 @@ export_to_gif(images[1], "cake_3d.gif") ## Image-to-3D -To generate a 3D object from another image, use the [`ShapEImg2ImgPipeline`]. You can use an existing image or generate an entirely new one. Let's use the the [Kandinsky 2.1](../api/pipelines/kandinsky) model to generate a new image. +To generate a 3D object from another image, use the [`ShapEImg2ImgPipeline`]. You can use an existing image or generate an entirely new one. Let's use the [Kandinsky 2.1](../api/pipelines/kandinsky) model to generate a new image. ```py from diffusers import DiffusionPipeline diff --git a/examples/community/run_onnx_controlnet.py b/examples/community/run_onnx_controlnet.py index 2b1123a4955c..69181b0a545e 100644 --- a/examples/community/run_onnx_controlnet.py +++ b/examples/community/run_onnx_controlnet.py @@ -553,7 +553,7 @@ def __call__( instead. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): - The initial image will be used as the starting point for the image generation process. Can also accpet + The initial image will be used as the starting point for the image generation process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded again. control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): diff --git a/examples/community/run_tensorrt_controlnet.py b/examples/community/run_tensorrt_controlnet.py index 724f393eb122..9fef7187ab79 100644 --- a/examples/community/run_tensorrt_controlnet.py +++ b/examples/community/run_tensorrt_controlnet.py @@ -657,7 +657,7 @@ def __call__( instead. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): - The initial image will be used as the starting point for the image generation process. Can also accpet + The initial image will be used as the starting point for the image generation process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded again. control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): diff --git a/examples/custom_diffusion/README.md b/examples/custom_diffusion/README.md index 9e3c387e3d34..e686933feb51 100644 --- a/examples/custom_diffusion/README.md +++ b/examples/custom_diffusion/README.md @@ -48,7 +48,7 @@ write_basic_config() Now let's get our dataset. Download dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it. -We also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the the given target image. The following flags enable the regularization `with_prior_preservation`, `real_prior` with `prior_loss_weight=1.`. +We also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the given target image. The following flags enable the regularization `with_prior_preservation`, `real_prior` with `prior_loss_weight=1.`. The `class_prompt` should be the category name same as target image. The collected real images are with text captions similar to the `class_prompt`. The retrieved image are saved in `class_data_dir`. You can disable `real_prior` to use generated images as regularization. To collect the real images use this command first before training. ```bash @@ -82,7 +82,7 @@ accelerate launch train_custom_diffusion.py \ **Use `--enable_xformers_memory_efficient_attention` for faster training with lower VRAM requirement (16GB per GPU). Follow [this guide](https://github.com/facebookresearch/xformers) for installation instructions.** -To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (whcih we HIGHLY recommend), follow these steps: +To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (which we HIGHLY recommend), follow these steps: * Install `wandb`: `pip install wandb`. * Authorize: `wandb login`. diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 606cc5c6cfdd..6ad79a47deb5 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1119,7 +1119,7 @@ def compute_text_embeddings(prompt): unet, optimizer, train_dataloader, lr_scheduler ) - # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index ac72974c4a1c..493430cadbdf 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -794,7 +794,7 @@ def main(args): text_encoder.requires_grad_(False) unet.requires_grad_(False) - # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 8ef666840b3a..caf04f430838 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -707,7 +707,7 @@ def main(args): text_encoder_two.requires_grad_(False) unet.requires_grad_(False) - # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": diff --git a/examples/research_projects/colossalai/README.md b/examples/research_projects/colossalai/README.md index 7c428bbce736..be94950b772e 100644 --- a/examples/research_projects/colossalai/README.md +++ b/examples/research_projects/colossalai/README.md @@ -41,7 +41,7 @@ The `text` include the tag `Teyvat`, `Name`,`Element`, `Weapon`, `Region`, `Mode ## Training -The arguement `placement` can be `cpu`, `auto`, `cuda`, with `cpu` the GPU RAM required can be minimized to 4GB but will deceleration, with `cuda` you can also reduce GPU memory by half but accelerated training, with `auto` a more balanced solution for speed and memory can be obtained。 +The argument `placement` can be `cpu`, `auto`, `cuda`, with `cpu` the GPU RAM required can be minimized to 4GB but will deceleration, with `cuda` you can also reduce GPU memory by half but accelerated training, with `auto` a more balanced solution for speed and memory can be obtained。 **___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** diff --git a/examples/research_projects/multi_subject_dreambooth/README.md b/examples/research_projects/multi_subject_dreambooth/README.md index d1a7705cfebb..5fff305f82be 100644 --- a/examples/research_projects/multi_subject_dreambooth/README.md +++ b/examples/research_projects/multi_subject_dreambooth/README.md @@ -323,7 +323,7 @@ accelerate launch train_dreambooth.py \ ### Using DreamBooth for other pipelines than Stable Diffusion -Altdiffusion also support dreambooth now, the runing comman is basically the same as abouve, all you need to do is replace the `MODEL_NAME` like this: +Altdiffusion also support dreambooth now, the runing comman is basically the same as above, all you need to do is replace the `MODEL_NAME` like this: One can now simply change the `pretrained_model_name_or_path` to another architecture such as [`AltDiffusion`](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion). ``` diff --git a/examples/research_projects/sdxl_flax/README.md b/examples/research_projects/sdxl_flax/README.md index fca21912982a..612fdf1edd43 100644 --- a/examples/research_projects/sdxl_flax/README.md +++ b/examples/research_projects/sdxl_flax/README.md @@ -151,7 +151,7 @@ telling JAX which input arguments are static, that is, arguments that are known at compile time and won't change. In our case, it is num_inference_steps, height, width and return_latents. -Once the function is compiled, these parameters are ommited from future calls and +Once the function is compiled, these parameters are omitted from future calls and cannot be changed without modifying the code and recompiling. ```python diff --git a/src/diffusers/pipelines/shap_e/renderer.py b/src/diffusers/pipelines/shap_e/renderer.py index ac5c06042e59..2145bc25c40a 100644 --- a/src/diffusers/pipelines/shap_e/renderer.py +++ b/src/diffusers/pipelines/shap_e/renderer.py @@ -911,7 +911,7 @@ def decode_to_image( n_coarse_samples=64, n_fine_samples=128, ): - # project the the paramters from the generated latents + # project the parameters from the generated latents projected_params = self.params_proj(latents) # update the mlp layers of the renderer @@ -955,7 +955,7 @@ def decode_to_mesh( query_batch_size: int = 4096, texture_channels: Tuple = ("R", "G", "B"), ): - # 1. project the the paramters from the generated latents + # 1. project the parameters from the generated latents projected_params = self.params_proj(latents) # 2. update the mlp layers of the renderer diff --git a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py index 9b962f6e0656..bf0a4eb475c0 100644 --- a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py +++ b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py @@ -20,7 +20,7 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): prefix_length (`int`): Max number of prefix tokens that will be supplied to the model. prefix_inner_dim (`int`): - The hidden size of the the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the + The hidden size of the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the CLIP text encoder. prefix_hidden_dim (`int`, *optional*): Hidden dim of the MLP if we encode the prefix. diff --git a/utils/custom_init_isort.py b/utils/custom_init_isort.py index e1e85974aeed..2de3940342d0 100644 --- a/utils/custom_init_isort.py +++ b/utils/custom_init_isort.py @@ -16,7 +16,7 @@ Utility that sorts the imports in the custom inits of Diffusers. Diffusers uses init files that delay the import of an object to when it's actually needed. This is to avoid the main init importing all models, which would make the line `import transformers` very slow when the user has all optional dependencies installed. The inits with -delayed imports have two halves: one definining a dictionary `_import_structure` which maps modules to the name of the +delayed imports have two halves: one defining a dictionary `_import_structure` which maps modules to the name of the objects in each module, and one in `TYPE_CHECKING` which looks like a normal init for type-checkers. `isort` or `ruff` properly sort the second half which looks like traditionl imports, the goal of this script is to sort the first half. diff --git a/utils/release.py b/utils/release.py index 758fb70caaca..a0800b99fbeb 100644 --- a/utils/release.py +++ b/utils/release.py @@ -130,7 +130,7 @@ def pre_release_work(patch=False): def post_release_work(): - """Do all the necesarry post-release steps.""" + """Do all the necessary post-release steps.""" # First let's get the current version current_version = get_version() dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0" From cc12f3ec929d0359080cb6233c598e4f6d1f4f69 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 16 Oct 2023 19:34:46 +0530 Subject: [PATCH 5/8] [Examples] Update with HFApi (#5393) * update training examples to use HFAPI. * update training example. * reflect the changes in the korean version too. * Empty-Commit --- docs/source/en/tutorials/basic_training.md | 29 ++++++-------- docs/source/ko/tutorials/basic_training.md | 32 +++++++-------- examples/dreambooth/train_dreambooth_flax.py | 40 ++++++------------- .../textual_inversion.py | 40 +++++++------------ .../train_unconditional.py | 40 ++++++------------- .../train_unconditional.py | 40 ++++++------------- 6 files changed, 79 insertions(+), 142 deletions(-) diff --git a/docs/source/en/tutorials/basic_training.md b/docs/source/en/tutorials/basic_training.md index c97447e54bc1..b2243a7597f6 100644 --- a/docs/source/en/tutorials/basic_training.md +++ b/docs/source/en/tutorials/basic_training.md @@ -284,22 +284,11 @@ Now you can wrap all these components together in a training loop with 🤗 Acce ```py >>> from accelerate import Accelerator ->>> from huggingface_hub import HfFolder, Repository, whoami +>>> from huggingface_hub import create_repo, upload_folder >>> from tqdm.auto import tqdm >>> from pathlib import Path >>> import os - ->>> def get_full_repo_name(model_id: str, organization: str = None, token: str = None): -... if token is None: -... token = HfFolder.get_token() -... if organization is None: -... username = whoami(token)["name"] -... return f"{username}/{model_id}" -... else: -... return f"{organization}/{model_id}" - - >>> def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler): ... # Initialize accelerator and tensorboard logging ... accelerator = Accelerator( @@ -309,11 +298,12 @@ Now you can wrap all these components together in a training loop with 🤗 Acce ... project_dir=os.path.join(config.output_dir, "logs"), ... ) ... if accelerator.is_main_process: -... if config.push_to_hub: -... repo_name = get_full_repo_name(Path(config.output_dir).name) -... repo = Repository(config.output_dir, clone_from=repo_name) -... elif config.output_dir is not None: +... if config.output_dir is not None: ... os.makedirs(config.output_dir, exist_ok=True) +... if config.push_to_hub: +... repo_id = create_repo( +... repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True +... ) ... accelerator.init_trackers("train_example") ... # Prepare everything @@ -371,7 +361,12 @@ Now you can wrap all these components together in a training loop with 🤗 Acce ... if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1: ... if config.push_to_hub: -... repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=True) +... upload_folder( +... repo_id=repo_id, +... folder_path=config.output_dir, +... commit_message=f"Epoch {epoch}", +... ignore_patterns=["step_*", "epoch_*"], +... ) ... else: ... pipeline.save_pretrained(config.output_dir) ``` diff --git a/docs/source/ko/tutorials/basic_training.md b/docs/source/ko/tutorials/basic_training.md index e18c82c4fd4b..7bd5ad44dd08 100644 --- a/docs/source/ko/tutorials/basic_training.md +++ b/docs/source/ko/tutorials/basic_training.md @@ -283,36 +283,27 @@ TensorBoard에 로깅, 그래디언트 누적 및 혼합 정밀도 학습을 쉽 ```py >>> from accelerate import Accelerator ->>> from huggingface_hub import HfFolder, Repository, whoami +>>> from huggingface_hub import create_repo, upload_folder >>> from tqdm.auto import tqdm >>> from pathlib import Path >>> import os ->>> def get_full_repo_name(model_id: str, organization: str = None, token: str = None): -... if token is None: -... token = HfFolder.get_token() -... if organization is None: -... username = whoami(token)["name"] -... return f"{username}/{model_id}" -... else: -... return f"{organization}/{model_id}" - - >>> def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler): -... # accelerator와 tensorboard 로깅 초기화 +... # Initialize accelerator and tensorboard logging ... accelerator = Accelerator( ... mixed_precision=config.mixed_precision, ... gradient_accumulation_steps=config.gradient_accumulation_steps, ... log_with="tensorboard", -... logging_dir=os.path.join(config.output_dir, "logs"), +... project_dir=os.path.join(config.output_dir, "logs"), ... ) ... if accelerator.is_main_process: -... if config.push_to_hub: -... repo_name = get_full_repo_name(Path(config.output_dir).name) -... repo = Repository(config.output_dir, clone_from=repo_name) -... elif config.output_dir is not None: +... if config.output_dir is not None: ... os.makedirs(config.output_dir, exist_ok=True) +... if config.push_to_hub: +... repo_id = create_repo( +... repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True +... ) ... accelerator.init_trackers("train_example") ... # 모든 것이 준비되었습니다. @@ -369,7 +360,12 @@ TensorBoard에 로깅, 그래디언트 누적 및 혼합 정밀도 학습을 쉽 ... if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1: ... if config.push_to_hub: -... repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=True) +... upload_folder( +... repo_id=repo_id, +... folder_path=config.output_dir, +... commit_message=f"Epoch {epoch}", +... ignore_patterns=["step_*", "epoch_*"], +... ) ... else: ... pipeline.save_pretrained(config.output_dir) ``` diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index 4ac4f969ee69..a436d36cebfd 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -4,7 +4,6 @@ import math import os from pathlib import Path -from typing import Optional import jax import jax.numpy as jnp @@ -16,7 +15,7 @@ from flax import jax_utils from flax.training import train_state from flax.training.common_utils import shard -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from jax.experimental.compilation_cache import compilation_cache as cc from PIL import Image from torch.utils.data import Dataset @@ -318,16 +317,6 @@ def __getitem__(self, index): return example -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - def get_params_to_save(params): return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params)) @@ -392,22 +381,14 @@ def main(): # Handle the repository creation if jax.process_index() == 0: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load the tokenizer and add the placeholder token as a additional special token if args.tokenizer_name: tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) @@ -668,7 +649,12 @@ def checkpoint(step=None): if args.push_to_hub: message = f"checkpoint-{step}" if step is not None else "End of training" - repo.push_to_hub(commit_message=message, blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message=message, + ignore_patterns=["step_*", "epoch_*"], + ) global_step = 0 diff --git a/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py b/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py index b19dd6e1103d..43667187596e 100644 --- a/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py +++ b/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py @@ -4,7 +4,7 @@ import os import random from pathlib import Path -from typing import Iterable, Optional +from typing import Iterable import numpy as np import PIL @@ -13,7 +13,7 @@ import torch.utils.checkpoint from accelerate import Accelerator from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import HfFolder, Repository, whoami +from huggingface_hub import create_repo, upload_folder from neural_compressor.utils import logger from packaging import version from PIL import Image @@ -413,16 +413,6 @@ def __getitem__(self, i): return example -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - def freeze_params(params): for param in params: param.requires_grad = False @@ -461,21 +451,14 @@ def main(): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - repo = Repository(args.output_dir, clone_from=repo_name) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load the tokenizer and add the placeholder token as a additional special token if args.tokenizer_name: tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) @@ -982,7 +965,12 @@ def attention_fetcher(x): ) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) accelerator.end_training() diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py index ba5ccd238fdc..5cad9f2fbed9 100644 --- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py +++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py @@ -4,7 +4,6 @@ import math import os from pathlib import Path -from typing import Optional import accelerate import datasets @@ -14,7 +13,7 @@ from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration from datasets import load_dataset -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer from onnxruntime.training.ortmodule import ORTModule from packaging import version @@ -277,16 +276,6 @@ def parse_args(): return args -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - def main(args): logging_dir = os.path.join(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration( @@ -360,22 +349,14 @@ def load_model_hook(models, input_dir): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Initialize the model if args.model_config_name_or_path is None: model = UNet2DModel( @@ -691,7 +672,12 @@ def transform_images(examples): ema_model.restore(unet.parameters()) if args.push_to_hub: - repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message=f"Epoch {epoch}", + ignore_patterns=["step_*", "epoch_*"], + ) accelerator.end_training() diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 12b63439fa68..74b8ed106834 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -6,7 +6,6 @@ import shutil from datetime import timedelta from pathlib import Path -from typing import Optional import accelerate import datasets @@ -16,7 +15,7 @@ from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration from datasets import load_dataset -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from packaging import version from torchvision import transforms from tqdm.auto import tqdm @@ -273,16 +272,6 @@ def parse_args(): return args -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - def main(args): logging_dir = os.path.join(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) @@ -356,22 +345,14 @@ def load_model_hook(models, input_dir): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Initialize the model if args.model_config_name_or_path is None: model = UNet2DModel( @@ -708,7 +689,12 @@ def transform_images(examples): ema_model.restore(unet.parameters()) if args.push_to_hub: - repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message=f"Epoch {epoch}", + ignore_patterns=["step_*", "epoch_*"], + ) accelerator.end_training() From de12776b3a7b3b8cf99940446e20eebf9710cd9d Mon Sep 17 00:00:00 2001 From: Gregg Helt Date: Mon, 16 Oct 2023 07:29:05 -0700 Subject: [PATCH 6/8] Add ability to mix usage of T2I-Adapter(s) and ControlNet(s). (#5362) * Add ability to mix usage of T2I-Adapter(s) and ControlNet(s). Previously, UNet2DConditional implemnetation onloy allowed use of one or the other. Adds new forward() arg down_intrablock_additional_residuals specifically for T2I-Adapters. If down_intrablock_addtional_residuals is not used, maintains backward compatibility with prior usage of only T2I-Adapter or ControlNet but not both * Improving forward() arg docs in src/diffusers/models/unet_2d_condition.py Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> * Add deprecation warning if down_block_additional_residues is used for T2I-Adapter (intrablock residuals) Co-authored-by: Patrick von Platen * Oops my bad, fixing last commit. * Added import of diffusers utils.deprecate * Conform to max line length * Modifying T2I-Adapter pipelines to reflect change to UNet forward() arg for T2I-Adapter residuals. --------- Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Co-authored-by: Patrick von Platen --- src/diffusers/models/unet_2d_condition.py | 40 ++++++++++++++----- .../pipeline_stable_diffusion_adapter.py | 2 +- .../pipeline_stable_diffusion_xl_adapter.py | 6 +-- 3 files changed, 34 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 4039fbfcc67a..06421305c301 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import UNet2DConditionLoadersMixin -from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, deprecate, scale_lora_layers, unscale_lora_layers from .activations import get_activation from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -778,6 +778,7 @@ def forward( added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: @@ -822,6 +823,13 @@ def forward( added_cond_kwargs: (`dict`, *optional*): A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. + down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added to UNet long skip connections from down blocks to up blocks + for example from ControlNet side model(s) + mid_block_additional_residual (`torch.Tensor`, *optional*): + additional residual to be added to UNet mid block output, for example from ControlNet side model + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: @@ -1000,15 +1008,28 @@ def forward( scale_lora_layers(self, lora_scale) is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None - is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate("T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} - if is_adapter and len(down_block_additional_residuals) > 0: - additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, @@ -1021,9 +1042,8 @@ def forward( ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) - - if is_adapter and len(down_block_additional_residuals) > 0: - sample += down_block_additional_residuals.pop(0) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) down_block_res_samples += res_samples @@ -1051,10 +1071,10 @@ def forward( # To support T2I-Adapter-XL if ( is_adapter - and len(down_block_additional_residuals) > 0 - and sample.shape == down_block_additional_residuals[0].shape + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape ): - sample += down_block_additional_residuals.pop(0) + sample += down_intrablock_additional_residuals.pop(0) if is_controlnet: sample = sample + mid_block_additional_residual diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index 0c7120c5b3ec..dca9e5fc3de2 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -813,7 +813,7 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - down_block_additional_residuals=[state.clone() for state in adapter_state], + down_intrablock_additional_residuals=[state.clone() for state in adapter_state], ).sample # perform guidance diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index b31d478a9d67..d4272696c23b 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -975,9 +975,9 @@ def __call__( added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if i < int(num_inference_steps * adapter_conditioning_factor): - down_block_additional_residuals = [state.clone() for state in adapter_state] + down_intrablock_additional_residuals = [state.clone() for state in adapter_state] else: - down_block_additional_residuals = None + down_intrablock_additional_residuals = None noise_pred = self.unet( latent_model_input, @@ -986,7 +986,7 @@ def __call__( cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, - down_block_additional_residuals=down_block_additional_residuals, + down_intrablock_additional_residuals=down_intrablock_additional_residuals, )[0] # perform guidance From 57239dacd03c55ab6913efe0bf6ef7bb958f3c92 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 16 Oct 2023 16:29:50 +0200 Subject: [PATCH 7/8] make style --- src/diffusers/models/unet_2d_condition.py | 16 ++++---- .../versatile_diffusion/modeling_text_unet.py | 41 +++++++++++++++---- 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 06421305c301..0ce2e04ad99a 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import UNet2DConditionLoadersMixin -from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, deprecate, scale_lora_layers, unscale_lora_layers +from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers from .activations import get_activation from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -824,8 +824,8 @@ def forward( A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): - additional residuals to be added to UNet long skip connections from down blocks to up blocks - for example from ControlNet side model(s) + additional residuals to be added to UNet long skip connections from down blocks to up blocks for + example from ControlNet side model(s) mid_block_additional_residual (`torch.Tensor`, *optional*): additional residual to be added to UNet mid block output, for example from ControlNet side model down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): @@ -1014,12 +1014,14 @@ def forward( # T2I-Adapter and ControlNet both use down_block_additional_residuals arg # but can only use one or the other if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: - deprecate("T2I should not use down_block_additional_residuals", - "1.3.0", - "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", - standard_warn=False) + standard_warn=False, + ) down_intrablock_additional_residuals = down_block_additional_residuals is_adapter = True diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 2ed3deeb1225..a70903b4bd74 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -987,6 +987,7 @@ def forward( added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: @@ -1031,6 +1032,13 @@ def forward( added_cond_kwargs: (`dict`, *optional*): A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. + down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added to UNet long skip connections from down blocks to up blocks for + example from ControlNet side model(s) + mid_block_additional_residual (`torch.Tensor`, *optional*): + additional residual to be added to UNet mid block output, for example from ControlNet side model + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: @@ -1216,15 +1224,31 @@ def forward( scale_lora_layers(self, lora_scale) is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None - is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated " + " and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only" + " be used for ControlNet. Please make sure use" + " `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlockFlat additional_residuals = {} - if is_adapter and len(down_block_additional_residuals) > 0: - additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, @@ -1237,9 +1261,8 @@ def forward( ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) - - if is_adapter and len(down_block_additional_residuals) > 0: - sample += down_block_additional_residuals.pop(0) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) down_block_res_samples += res_samples @@ -1267,10 +1290,10 @@ def forward( # To support T2I-Adapter-XL if ( is_adapter - and len(down_block_additional_residuals) > 0 - and sample.shape == down_block_additional_residuals[0].shape + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape ): - sample += down_block_additional_residuals.pop(0) + sample += down_intrablock_additional_residuals.pop(0) if is_controlnet: sample = sample + mid_block_additional_residual From 8b3d2aeaf8ed1489752a9dc4ebf69e72c7af6bf0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 17 Oct 2023 11:17:06 +0530 Subject: [PATCH 8/8] [Core] Fix/pipeline without text encoders for SDXL (#5301) * fix: sdxl pipeline when unet is not available. * fix moe * account for text * ifx more * don't make unet optional. * Apply suggestions from code review Co-authored-by: Patrick von Platen * split conditionals. * add optional components to sdxl pipeline * propagate changes to the rest of the pipelines. * add: test * add to all * fix: rest of the pipelines. * use pipeline_class variable * separate pipeline mixin * use safe_serialization * fix: test * access actual output. * add: optional test to adapter and ip2p sdxl pipeline tests/ * add optional test to controlnet sdxl. * fix tests * fix ip2p tests * fix more * fifx more. * use np output type. * fix for StableDiffusionXLMultiControlNetPipelineFastTests. * fix: SDXLOptionalComponentsTesterMixin * Apply suggestions from code review Co-authored-by: Patrick von Platen * fix tests * Empty-Commit * revert previous * quality * fix: test --------- Co-authored-by: Patrick von Platen --- .../pipeline_controlnet_inpaint_sd_xl.py | 62 ++++++-- .../controlnet/pipeline_controlnet_sd_xl.py | 67 +++++--- .../pipeline_controlnet_sd_xl_img2img.py | 55 +++++-- .../pipeline_stable_diffusion_xl.py | 63 ++++++-- .../pipeline_stable_diffusion_xl_img2img.py | 55 +++++-- .../pipeline_stable_diffusion_xl_inpaint.py | 54 +++++-- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 38 ++++- .../pipeline_stable_diffusion_xl_adapter.py | 62 ++++++-- .../versatile_diffusion/modeling_text_unet.py | 2 + .../controlnet/test_controlnet_sdxl.py | 24 ++- .../test_stable_diffusion_xl.py | 11 +- .../test_stable_diffusion_xl_adapter.py | 13 +- .../test_stable_diffusion_xl_img2img.py | 7 +- ...stable_diffusion_xl_instruction_pix2pix.py | 16 +- tests/pipelines/test_pipelines_common.py | 144 ++++++++++++++++++ 15 files changed, 545 insertions(+), 128 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 4418ede74bd3..cf51fbe57180 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -168,7 +168,7 @@ class StableDiffusionXLControlNetInpaintPipeline( [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "text_encoder"] + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, @@ -317,12 +317,17 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -438,7 +443,11 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -447,7 +456,12 @@ def encode_prompt( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -459,10 +473,15 @@ def encode_prompt( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -885,7 +904,14 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N return timesteps, num_inference_steps - t_start def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype, + text_encoder_projection_dim=None, ): if self.config.requires_aesthetics_score: add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) @@ -895,7 +921,7 @@ def _get_add_time_ids( add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -1391,6 +1417,11 @@ def denoising_value_valid(dnv): # 10. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, @@ -1398,6 +1429,7 @@ def denoising_value_valid(dnv): aesthetic_score, negative_aesthetic_score, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index f634f3f389a9..59573665867e 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -139,9 +139,9 @@ class StableDiffusionXLControlNetPipeline( watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no watermarker is used. """ - model_cpu_offload_seq = ( - "text_encoder->text_encoder_2->unet->vae" # leave controlnet out on purpose because it iterates with unet - ) + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, @@ -285,12 +285,17 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -406,7 +411,11 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -415,7 +424,12 @@ def encode_prompt( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -427,10 +441,15 @@ def encode_prompt( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -706,11 +725,13 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype return latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids - def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -1088,8 +1109,17 @@ def __call__( target_size = target_size or (height, width) add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: @@ -1098,6 +1128,7 @@ def __call__( negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 3375855ba8ee..033544e893bd 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -183,7 +183,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( watermarker will be used. """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "text_encoder"] + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, @@ -329,12 +329,17 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -450,7 +455,11 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -459,7 +468,12 @@ def encode_prompt( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -471,10 +485,15 @@ def encode_prompt( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -832,6 +851,7 @@ def _get_add_time_ids( negative_crops_coords_top_left, negative_target_size, dtype, + text_encoder_projection_dim=None, ): if self.config.requires_aesthetics_score: add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) @@ -843,7 +863,7 @@ def _get_add_time_ids( add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -1275,6 +1295,12 @@ def __call__( if negative_target_size is None: negative_target_size = target_size add_text_embeds = pooled_prompt_embeds + + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, @@ -1285,6 +1311,7 @@ def __call__( negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 55bf929a2ee2..2658b58de240 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -140,6 +140,7 @@ class StableDiffusionXLPipeline( watermarker will be used. """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, @@ -167,6 +168,7 @@ def __init__( self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = self.unet.config.sample_size add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() @@ -275,12 +277,17 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -396,7 +403,11 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -405,7 +416,12 @@ def encode_prompt( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -417,10 +433,15 @@ def encode_prompt( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -533,11 +554,13 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents - def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -843,8 +866,17 @@ def __call__( # 7. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( @@ -852,6 +884,7 @@ def __call__( negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index b436f404d5ea..75eb02a48614 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -143,8 +143,7 @@ class StableDiffusionXLImg2ImgPipeline( watermarker will be used. """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - - _optional_components = ["tokenizer", "text_encoder"] + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, @@ -282,12 +281,17 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -403,7 +407,11 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -412,7 +420,12 @@ def encode_prompt( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -424,10 +437,15 @@ def encode_prompt( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -618,6 +636,7 @@ def _get_add_time_ids( negative_crops_coords_top_left, negative_target_size, dtype, + text_encoder_projection_dim=None, ): if self.config.requires_aesthetics_score: add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) @@ -629,7 +648,7 @@ def _get_add_time_ids( add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -983,6 +1002,11 @@ def denoising_value_valid(dnv): negative_target_size = target_size add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, @@ -993,6 +1017,7 @@ def denoising_value_valid(dnv): negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index c04d2c0518c1..4af25afbeb3b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -290,7 +290,7 @@ class StableDiffusionXLInpaintPipeline( """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "text_encoder"] + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, @@ -431,12 +431,17 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -552,7 +557,11 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -561,7 +570,12 @@ def encode_prompt( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -573,10 +587,15 @@ def encode_prompt( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -836,6 +855,7 @@ def _get_add_time_ids( negative_crops_coords_top_left, negative_target_size, dtype, + text_encoder_projection_dim=None, ): if self.config.requires_aesthetics_score: add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) @@ -847,7 +867,7 @@ def _get_add_time_ids( add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -1289,6 +1309,11 @@ def denoising_value_valid(dnv): negative_target_size = target_size add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, @@ -1299,6 +1324,7 @@ def denoising_value_valid(dnv): negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 8cd7f46e633a..0427214f8374 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -31,11 +31,13 @@ from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( + USE_PEFT_BACKEND, deprecate, is_invisible_watermark_available, is_torch_xla_available, logging, replace_example_docstring, + scale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -150,6 +152,7 @@ class StableDiffusionXLInstructPix2PixPipeline( watermarker will be used. """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, @@ -280,8 +283,17 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend) + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -390,7 +402,8 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + prompt_embeds_dtype = self.text_encoder_2.dtype if self.text_encoder_2 is not None else self.unet.dtype + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -399,7 +412,7 @@ def encode_prompt( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -552,11 +565,13 @@ def prepare_image_latents( return image_latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids - def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -871,8 +886,17 @@ def __call__( # 10. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) if do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index d4272696c23b..b606b9b50c31 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -160,6 +160,7 @@ class StableDiffusionXLAdapterPipeline( Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, @@ -290,12 +291,17 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -411,7 +417,11 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -420,7 +430,12 @@ def encode_prompt( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -432,10 +447,15 @@ def encode_prompt( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -550,11 +570,13 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype return latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids - def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -928,8 +950,17 @@ def __call__( adapter_state[k] = torch.cat([v] * 2, dim=0) add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( @@ -937,6 +968,7 @@ def __call__( negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index a70903b4bd74..717db3bbdb34 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -5,6 +5,8 @@ import torch.nn as nn import torch.nn.functional as F +from diffusers.utils import deprecate + from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin from ...models.activations import get_activation diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 4fff88434bc3..be786ebe3000 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -42,6 +42,7 @@ PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, ) @@ -49,7 +50,11 @@ class StableDiffusionXLControlNetPipelineFastTests( - PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineLatentTesterMixin, + PipelineKarrasSchedulerTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, + unittest.TestCase, ): pipeline_class = StableDiffusionXLControlNetPipeline params = TEXT_TO_IMAGE_PARAMS @@ -179,6 +184,9 @@ def test_xformers_attention_forwardGenerator_pass(self): def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + @require_torch_gpu def test_stable_diffusion_xl_offloads(self): pipes = [] @@ -324,7 +332,7 @@ def test_controlnet_sdxl_guess(self): class StableDiffusionXLMultiControlNetPipelineFastTests( - PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase + PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase ): pipeline_class = StableDiffusionXLControlNetPipeline params = TEXT_TO_IMAGE_PARAMS @@ -470,7 +478,7 @@ def get_dummy_inputs(self, device, seed=0): "generator": generator, "num_inference_steps": 2, "guidance_scale": 6.0, - "output_type": "numpy", + "output_type": "np", "image": images, } @@ -522,9 +530,12 @@ def test_xformers_attention_forwardGenerator_pass(self): def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) + def test_save_load_optional_components(self): + return self._test_save_load_optional_components() + class StableDiffusionXLMultiControlNetOneModelPipelineFastTests( - PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase ): pipeline_class = StableDiffusionXLControlNetPipeline params = TEXT_TO_IMAGE_PARAMS @@ -646,7 +657,7 @@ def get_dummy_inputs(self, device, seed=0): "generator": generator, "num_inference_steps": 2, "guidance_scale": 6.0, - "output_type": "numpy", + "output_type": "np", "image": images, } @@ -702,6 +713,9 @@ def test_xformers_attention_forwardGenerator_pass(self): def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + def test_negative_conditions(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index cebd860a4379..4906670890e8 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -35,13 +35,15 @@ from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin enable_full_determinism() -class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusionXLPipelineFastTests( + PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionXLPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -114,8 +116,6 @@ def get_dummy_components(self): "tokenizer": tokenizer, "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, - # "safety_checker": None, - # "feature_extractor": None, } return components @@ -233,6 +233,9 @@ def test_attention_slicing_forward_pass(self): def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + @require_torch_gpu def test_stable_diffusion_xl_offloads(self): pipes = [] diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py index 92c22ca2c34c..0e7a13bc876b 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py @@ -34,13 +34,19 @@ from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference +from ..test_pipelines_common import ( + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, + assert_mean_pixel_difference, +) enable_full_determinism() -class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionXLAdapterPipelineFastTests( + PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionXLAdapterPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS @@ -215,6 +221,9 @@ def test_total_downscale_factor(self, adapter_type): expected_out_image_size, ) + def test_save_load_optional_components(self): + return self._test_save_load_optional_components() + class StableDiffusionXLMultiAdapterPipelineFastTests( StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index ba7d3e8be30f..97c19108947f 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -38,7 +38,7 @@ TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS, ) -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin enable_full_determinism() @@ -341,7 +341,7 @@ def test_stable_diffusion_xl_img2img_negative_conditions(self): class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests( - PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase ): pipeline_class = StableDiffusionXLImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} @@ -600,3 +600,6 @@ def test_attention_slicing_forward_pass(self): def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) + + def test_save_load_optional_components(self): + self._test_save_load_optional_components() diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py index ca4017d11b79..e20f8a0b54db 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py @@ -36,14 +36,23 @@ TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS, ) -from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, +) enable_full_determinism() class StableDiffusionXLInstructPix2PixPipelineFastTests( - PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineLatentTesterMixin, + PipelineKarrasSchedulerTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, + unittest.TestCase, ): pipeline_class = StableDiffusionXLInstructPix2PixPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"} @@ -175,3 +184,6 @@ def test_latents_input(self): def test_cfg(self): pass + + def test_save_load_optional_components(self): + self._test_save_load_optional_components() diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 6f2674a7b8f6..ae13d0d3e9fa 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -974,6 +974,150 @@ def test_push_to_hub_in_organization(self): delete_repo(self.org_repo_id, token=TOKEN) +# For SDXL and its derivative pipelines (such as ControlNet), we have the text encoders +# and the tokenizers as optional components. So, we need to override the `test_save_load_optional_components()` +# test for all such pipelines. This requires us to use a custom `encode_prompt()` function. +class SDXLOptionalComponentsTesterMixin: + def encode_prompt( + self, tokenizers, text_encoders, prompt: str, num_images_per_prompt: int = 1, negative_prompt: str = None + ): + device = text_encoders[0].device + + if isinstance(prompt, str): + prompt = [prompt] + batch_size = len(prompt) + + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + if negative_prompt is None: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + else: + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder(uncond_input.input_ids.to(device), output_hidden_states=True) + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + bs_embed, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # for classifier-free guidance + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + # for classifier-free guidance + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def _test_save_load_optional_components(self, expected_max_difference=1e-4): + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + + tokenizer = components.pop("tokenizer") + tokenizer_2 = components.pop("tokenizer_2") + text_encoder = components.pop("text_encoder") + text_encoder_2 = components.pop("text_encoder_2") + + tokenizers = [tokenizer, tokenizer_2] if tokenizer is not None else [tokenizer_2] + text_encoders = [text_encoder, text_encoder_2] if text_encoder is not None else [text_encoder_2] + prompt = inputs.pop("prompt") + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt(tokenizers, text_encoders, prompt) + inputs["prompt_embeds"] = prompt_embeds + inputs["negative_prompt_embeds"] = negative_prompt_embeds + inputs["pooled_prompt_embeds"] = pooled_prompt_embeds + inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds + + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(generator_device) + _ = inputs.pop("prompt") + inputs["prompt_embeds"] = prompt_embeds + inputs["negative_prompt_embeds"] = negative_prompt_embeds + inputs["pooled_prompt_embeds"] = pooled_prompt_embeds + inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds + + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, expected_max_difference) + + # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # reference image.