From 4eb297e676253b13a205512d9619de1e45c99bfe Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 2 Jan 2023 08:39:31 +0100 Subject: [PATCH 01/26] [Lora] first upload --- examples/lora/train_lora.py | 795 ++++++++++++++++++++++++++++++++++++ 1 file changed, 795 insertions(+) create mode 100644 examples/lora/train_lora.py diff --git a/examples/lora/train_lora.py b/examples/lora/train_lora.py new file mode 100644 index 000000000000..47ea768def0d --- /dev/null +++ b/examples/lora/train_lora.py @@ -0,0 +1,795 @@ +import argparse +import hashlib +import itertools +import math +import os +import warnings +from pathlib import Path +from typing import Optional + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.utils.data import Dataset + +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from huggingface_hub import HfFolder, Repository, whoami +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.10.0.dev0") + +logger = get_logger(__name__) + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel +# elif model_class == "RobertaSeriesModelWithTransformation": +# from diffusers.pipelines.lora.modeling_lora import RobertaSeriesModelWithTransformation +# +# return RobertaSeriesModelWithTransformation +# else: +# raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + 
action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" + ) + parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
+ ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 
+ It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + return example + + +def collate_fn(examples, with_prior_preservation=False): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.cat(input_ids, dim=0) + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
+ + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with="tensorboard", + logging_dir=logging_dir, + ) + + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) + + if args.seed is not None: + set_seed(args.seed) + + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, 
clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, + revision=args.revision, + use_fast=False, + ) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load models and create wrapper for stable diffusion + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + ) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + vae.requires_grad_(False) + if not args.train_text_encoder: + text_encoder.requires_grad_(False) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + params_to_optimize = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + ) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=1, + ) + + # Scheduler and math around the number of training steps. 
+ overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + accelerator.register_for_checkpointing(lr_scheduler) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu. + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + vae.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = resume_global_step // num_update_steps_per_epoch + resume_step = resume_global_step % num_update_steps_per_epoch + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder.train() + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. 
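                # (Sketch of the combined objective, using the two halves of the batch that
                # collate_fn concatenates: instance images first, then class images.)
                #   loss = MSE(model_pred, target) + prior_loss_weight * MSE(model_pred_prior, target_prior)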
+ loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + + # Create the pipeline using using the trained modules and save it. + if accelerator.is_main_process: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + revision=args.revision, + ) + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) From 67f4e5a27f8d9644373f95ac3e06b36c9a69fcd9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 2 Jan 2023 13:21:09 +0100 Subject: [PATCH 02/26] add first lora version --- examples/lora/train_lora.py | 23 +++--- src/diffusers/models/cross_attention.py | 93 +++++++++++++++++++++++ src/diffusers/models/unet_2d_condition.py | 25 +++++- 3 files changed, 130 insertions(+), 11 deletions(-) diff --git a/examples/lora/train_lora.py b/examples/lora/train_lora.py index 47ea768def0d..381ed8d0126e 100644 --- a/examples/lora/train_lora.py +++ b/examples/lora/train_lora.py @@ -16,6 +16,7 @@ from accelerate.logging import get_logger from accelerate.utils import set_seed from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers.models.cross_attention import LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version from diffusers.utils.import_utils import is_xformers_available @@ -286,7 +287,7 @@ def parse_args(input_args=None): return args -class DreamBoothDataset(Dataset): +class LoRADataset(Dataset): """ A dataset to prepare the instance and class images with the prompts for fine-tuning the model. It pre-processes the images and the tokenizes prompts. @@ -535,20 +536,22 @@ def main(args): revision=args.revision, ) + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.requires_grad_(False) + if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): unet.enable_xformers_memory_efficient_attention() else: raise ValueError("xformers is not available. 
Make sure it is installed correctly") - vae.requires_grad_(False) - if not args.train_text_encoder: - text_encoder.requires_grad_(False) + num_lora_layers = unet.num_attention_layers - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - if args.train_text_encoder: - text_encoder.gradient_checkpointing_enable() + if args.enable_xformers_memory_efficient_attention: + lora_attention_layers = [LoRAXFormersCrossAttnProcessor(query_dim, inner_dim, cross_attention_dim, rank=args.lora_rank) for _ in range(num_lora_layers)] + else: + lora_attention_layers = [LoRACrossAttnProcessor(query_dim, inner_dim, cross_attention_dim, rank=args.lora_rank) for _ in range(num_lora_layers)] if args.scale_lr: args.learning_rate = ( @@ -569,7 +572,7 @@ def main(args): optimizer_class = torch.optim.AdamW params_to_optimize = ( - itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + itertools.chain([layer.parameters() for layer in lora_attention_layers]) ) optimizer = optimizer_class( params_to_optimize, @@ -581,7 +584,7 @@ def main(args): noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - train_dataset = DreamBoothDataset( + train_dataset = LoRADataset( instance_data_root=args.instance_data_dir, instance_prompt=args.instance_prompt, class_data_root=args.class_data_dir if args.with_prior_preservation else None, diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 98173cb8a406..70824b7d9b4f 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -237,6 +237,65 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No return hidden_states +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4): + super().__init__() + + if rank > min(in_features, out_features): + raise ValueError( + f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}" + ) + + self.lora_down = nn.Linear(in_features, rank, bias=False) + self.lora_up = nn.Linear(rank, out_features, bias=False) + self.scale = 1.0 + + nn.init.normal_(self.lora_down.weight, std=1 / rank) + nn.init.zeros_(self.lora_up.weight) + + def forward(self, hidden_states): + down_hidden_states = self.lora_down(hidden_states) + up_hidden_states = self.lora_up(down_hidden_states) + + return up_hidden_states + + +class LoRACrossAttnProcessor(nn.Module): + def __init__(self, query_dim, inner_dim, cross_attention_dim, rank=4): + super().__init__() + + self.to_q_lora = LoRALinearLayer(query_dim, inner_dim) + self.to_k_lora = LoRALinearLayer(query_dim, inner_dim) + self.to_v_lora = LoRALinearLayer(query_dim, inner_dim) + self.to_out_lora = LoRALinearLayer(query_dim, inner_dim) + + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) + + key = attn.head_to_batch_dim(key) + value = 
attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + class CrossAttnAddedKVProcessor: def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): residual = hidden_states @@ -307,6 +366,40 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No return hidden_states +class LoRAXFormersCrossAttnProcessor(nn.Module): + def __init__(self, query_dim, inner_dim, cross_attention_dim, rank=4): + super().__init__() + + self.to_q_lora = LoRALinearLayer(query_dim, inner_dim) + self.to_k_lora = LoRALinearLayer(query_dim, inner_dim) + self.to_v_lora = LoRALinearLayer(query_dim, inner_dim) + self.to_out_lora = LoRALinearLayer(query_dim, inner_dim) + + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query).contiguous() + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) + + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + class SlicedAttnProcessor: def __init__(self, slice_size): self.slice_size = slice_size diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 8099cd8421fb..dfceb3efd713 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -266,8 +266,31 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) - def set_attn_processor(self, processor: AttnProcessor): + @property + def num_attention_layers(self): # set recursively + count = 0 + + def fn_recursive_count_processor(module: torch.nn.Module, count: int): + if hasattr(module, "set_processor"): + count += 1 + + for child in module.children(): + count = fn_recursive_count_processor(child) + + return count + + for module in self.children(): + count += fn_recursive_count_processor(module) + + return count + + def set_attn_processor(self, processor: Union[AttnProcessor, List[AttnProcessor]]): + count = self.num_attention_layers + + if isinstance(processor, list) and len(processor) != count: + raise ValueError(f"A list of processors was passed, but the number of processors {len(processor)} does not match the number of attention layers: {count}. 
Please make sure to pass {count} processor classes.") + def fn_recursive_attn_processor(module: torch.nn.Module): if hasattr(module, "set_processor"): module.set_processor(processor) From 24993c47aa9af0755ca12812afddcd433eeb4833 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 2 Jan 2023 13:29:28 +0100 Subject: [PATCH 03/26] upload --- examples/lora/train_lora.py | 2 +- src/diffusers/models/unet_2d_condition.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/lora/train_lora.py b/examples/lora/train_lora.py index 381ed8d0126e..6aa93d5a6cfb 100644 --- a/examples/lora/train_lora.py +++ b/examples/lora/train_lora.py @@ -546,7 +546,7 @@ def main(args): else: raise ValueError("xformers is not available. Make sure it is installed correctly") - num_lora_layers = unet.num_attention_layers + num_lora_layers = unet.num_attn_layers if args.enable_xformers_memory_efficient_attention: lora_attention_layers = [LoRAXFormersCrossAttnProcessor(query_dim, inner_dim, cross_attention_dim, rank=args.lora_rank) for _ in range(num_lora_layers)] diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index dfceb3efd713..a194796899fe 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -267,7 +267,7 @@ def __init__( self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) @property - def num_attention_layers(self): + def num_attn_layers(self): # set recursively count = 0 @@ -286,7 +286,7 @@ def fn_recursive_count_processor(module: torch.nn.Module, count: int): return count def set_attn_processor(self, processor: Union[AttnProcessor, List[AttnProcessor]]): - count = self.num_attention_layers + count = self.num_attn_layers if isinstance(processor, list) and len(processor) != count: raise ValueError(f"A list of processors was passed, but the number of processors {len(processor)} does not match the number of attention layers: {count}. Please make sure to pass {count} processor classes.") From 943e7f4dd293ac5cfd1d9c1be8edd6227bb5e52e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 2 Jan 2023 13:42:04 +0100 Subject: [PATCH 04/26] more --- examples/lora/train_lora.py | 13 +++++++++---- src/diffusers/models/unet_2d_condition.py | 2 +- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/examples/lora/train_lora.py b/examples/lora/train_lora.py index 6aa93d5a6cfb..8cd8a6aa2ee7 100644 --- a/examples/lora/train_lora.py +++ b/examples/lora/train_lora.py @@ -5,7 +5,7 @@ import os import warnings from pathlib import Path -from typing import Optional +from typing import Optional, List, Union import torch import torch.nn.functional as F @@ -546,12 +546,17 @@ def main(args): else: raise ValueError("xformers is not available. 
Make sure it is installed correctly") - num_lora_layers = unet.num_attn_layers + num_lora_layers = unet.num_attn_processors + + attention_head_dims: Union[List[int], int] = unet.config.attention_head_dim + + query_dim = unet.config.block_out_channels + cross_attention_dim = unet.config.cross_attention_dim if args.enable_xformers_memory_efficient_attention: - lora_attention_layers = [LoRAXFormersCrossAttnProcessor(query_dim, inner_dim, cross_attention_dim, rank=args.lora_rank) for _ in range(num_lora_layers)] + lora_attention_layers = [LoRAXFormersCrossAttnProcessor(query_dim, query_dim, cross_attention_dim, rank=args.lora_rank) for _ in range(num_lora_layers)] else: - lora_attention_layers = [LoRACrossAttnProcessor(query_dim, inner_dim, cross_attention_dim, rank=args.lora_rank) for _ in range(num_lora_layers)] + lora_attention_layers = [LoRACrossAttnProcessor(query_dim, query_dim, cross_attention_dim, rank=args.lora_rank) for _ in range(num_lora_layers)] if args.scale_lr: args.learning_rate = ( diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index a194796899fe..63e656bb59be 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -267,7 +267,7 @@ def __init__( self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) @property - def num_attn_layers(self): + def num_attn_processors(self): # set recursively count = 0 From e7293d06ad228d71add8505f64fc32fdfb4e7426 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 3 Jan 2023 10:06:56 +0000 Subject: [PATCH 05/26] first training --- examples/lora/README.md | 323 ++++++++++++++++++++++ examples/lora/train_lora.py | 49 ++-- src/diffusers/models/cross_attention.py | 20 +- src/diffusers/models/unet_2d_condition.py | 43 +-- 4 files changed, 389 insertions(+), 46 deletions(-) create mode 100644 examples/lora/README.md diff --git a/examples/lora/README.md b/examples/lora/README.md new file mode 100644 index 000000000000..2858c04c48b0 --- /dev/null +++ b/examples/lora/README.md @@ -0,0 +1,323 @@ +# DreamBooth training example + +[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject. +The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for stable diffusion. + + +## Running locally with PyTorch +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e . +``` + +Then cd in the example folder and run +```bash +pip install -r requirements.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +Or for a default accelerate configuration without answering questions about your environment + +```bash +accelerate config default +``` + +Or if your environment doesn't support an interactive shell e.g. 
a notebook + +```python +from accelerate.utils import write_basic_config +write_basic_config() +``` + +### Dog toy example + +Now let's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data. + +And launch the training using + +**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export INSTANCE_DIR="path-to-instance-images" +export OUTPUT_DIR="path-to-save-model" + +accelerate launch train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a photo of sks dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=1 \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=400 +``` + +### Training with prior-preservation loss + +Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data. +According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time. + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export INSTANCE_DIR="path-to-instance-images" +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="path-to-save-model" + +accelerate launch train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --output_dir=$OUTPUT_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=1 \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --max_train_steps=800 +``` + + +### Training on a 16GB GPU: + +With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes it's possible to run train dreambooth on a 16GB GPU. + +To install `bitandbytes` please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation). 
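
Under the hood, the memory saving comes from the training script swapping the optimizer class when `--use_8bit_adam` is passed. A minimal sketch of that switch (assuming `bitsandbytes` is installed; the `model` below is just a stand-in for whichever parameters are actually being trained):

```python
import torch

try:
    import bitsandbytes as bnb
    optimizer_class = bnb.optim.AdamW8bit  # 8-bit optimizer states -> lower VRAM
except ImportError:
    optimizer_class = torch.optim.AdamW    # note: the script raises instead of falling back

model = torch.nn.Linear(320, 320)          # stand-in for the trainable parameters
optimizer = optimizer_class(
    model.parameters(),
    lr=5e-6,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8,
)
```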
+ +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export INSTANCE_DIR="path-to-instance-images" +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="path-to-save-model" + +accelerate launch train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --output_dir=$OUTPUT_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=2 --gradient_checkpointing \ + --use_8bit_adam \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --max_train_steps=800 +``` + +### Training on a 8 GB GPU: + +By using [DeepSpeed](https://www.deepspeed.ai/) it's possible to offload some +tensors from VRAM to either CPU or NVME allowing to train with less VRAM. + +DeepSpeed needs to be enabled with `accelerate config`. During configuration +answer yes to "Do you want to use DeepSpeed?". With DeepSpeed stage 2, fp16 +mixed precision and offloading both parameters and optimizer state to cpu it's +possible to train on under 8 GB VRAM with a drawback of requiring significantly +more RAM (about 25 GB). See [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options. + +Changing the default Adam optimizer to DeepSpeed's special version of Adam +`deepspeed.ops.adam.DeepSpeedCPUAdam` gives a substantial speedup but enabling +it requires CUDA toolchain with the same version as pytorch. 8-bit optimizer +does not seem to be compatible with DeepSpeed at the moment. + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export INSTANCE_DIR="path-to-instance-images" +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="path-to-save-model" + +accelerate launch --mixed_precision="fp16" train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --output_dir=$OUTPUT_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --sample_batch_size=1 \ + --gradient_accumulation_steps=1 --gradient_checkpointing \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --max_train_steps=800 +``` + +### Fine-tune text encoder with the UNet. + +The script also allows to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning `text_encoder` gives much better results especially on faces. +Pass the `--train_text_encoder` argument to the script to enable training `text_encoder`. + +___Note: Training text encoder requires more memory, with this option the training won't fit on 16GB GPU. 
It needs at least 24GB VRAM.___ + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export INSTANCE_DIR="path-to-instance-images" +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="path-to-save-model" + +accelerate launch train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_text_encoder \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --output_dir=$OUTPUT_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --use_8bit_adam \ + --gradient_checkpointing \ + --learning_rate=2e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --max_train_steps=800 +``` + +### Using DreamBooth for other pipelines than Stable Diffusion + +Altdiffusion also support dreambooth now, the runing comman is basically the same as abouve, all you need to do is replace the `MODEL_NAME` like this: +One can now simply change the `pretrained_model_name_or_path` to another architecture such as [`AltDiffusion`](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion). + +``` +export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion-m9" +or +export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion" +``` + +### Inference + +Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt. + +```python +from diffusers import StableDiffusionPipeline +import torch + +model_id = "path-to-your-trained-model" +pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") + +prompt = "A photo of sks dog in a bucket" +image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] + +image.save("dog-bucket.png") +``` + +### Inference from a training checkpoint + +You can also perform inference from one of the checkpoints saved during the training process, if you used the `--checkpointing_steps` argument. Please, refer to [the documentation](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint) to see how to do it. + +## Training with Flax/JAX + +For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script. 
+ +____Note: The flax example don't yet support features like gradient checkpoint, gradient accumulation etc, so to use flax for faster training we will need >30GB cards.___ + + +Before running the scripts, make sure to install the library's training dependencies: + +```bash +pip install -U -r requirements_flax.txt +``` + + +### Training without prior preservation loss + +```bash +export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" +export INSTANCE_DIR="path-to-instance-images" +export OUTPUT_DIR="path-to-save-model" + +python train_dreambooth_flax.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a photo of sks dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --learning_rate=5e-6 \ + --max_train_steps=400 +``` + + +### Training with prior preservation loss + +```bash +export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" +export INSTANCE_DIR="path-to-instance-images" +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="path-to-save-model" + +python train_dreambooth_flax.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --output_dir=$OUTPUT_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --learning_rate=5e-6 \ + --num_class_images=200 \ + --max_train_steps=800 +``` + + +### Fine-tune text encoder with the UNet. + +```bash +export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" +export INSTANCE_DIR="path-to-instance-images" +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="path-to-save-model" + +python train_dreambooth_flax.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_text_encoder \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --output_dir=$OUTPUT_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --learning_rate=2e-6 \ + --num_class_images=200 \ + --max_train_steps=800 +``` + +### Training with xformers: +You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and padding the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation. + +You can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint). 
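
For reference, the core idea behind the `LoRALinearLayer` introduced earlier in this series (patch 02) can be sketched in a few lines: the frozen base projection is left untouched, and only a pair of small rank-`r` matrices is trained, with the up-projection initialized to zero so training starts from the unmodified model. A minimal, self-contained sketch — the wrapper class, layer sizes, and names here are illustrative and not part of the PR:

```python
import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    """Illustrative stand-in: the PR's LoRALinearLayer applied on top of a frozen base projection."""

    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base.requires_grad_(False)                 # frozen pretrained weight
        self.lora_down = nn.Linear(base.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, base.out_features, bias=False)
        self.scale = scale
        nn.init.normal_(self.lora_down.weight, std=1 / rank)   # same init as the PR
        nn.init.zeros_(self.lora_up.weight)                    # update starts at zero

    def forward(self, hidden_states):
        return self.base(hidden_states) + self.scale * self.lora_up(self.lora_down(hidden_states))


layer = LoRALinearSketch(nn.Linear(320, 320), rank=4)
out = layer(torch.randn(2, 77, 320))
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(out.shape, trainable)  # torch.Size([2, 77, 320]) 2560
```

In the PR itself the LoRA layers are not wrapped around the base projections like this: the new `LoRACrossAttnProcessor` keeps them separate and adds `scale * self.to_q_lora(hidden_states)` (and likewise for key, value, and output) next to the frozen `attn.to_q`/`to_k`/`to_v`/`to_out` calls, which is why only the processors' parameters are handed to the optimizer.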
diff --git a/examples/lora/train_lora.py b/examples/lora/train_lora.py index 8cd8a6aa2ee7..65a2af0726aa 100644 --- a/examples/lora/train_lora.py +++ b/examples/lora/train_lora.py @@ -5,7 +5,7 @@ import os import warnings from pathlib import Path -from typing import Optional, List, Union +from typing import Optional import torch import torch.nn.functional as F @@ -16,7 +16,7 @@ from accelerate.logging import get_logger from accelerate.utils import set_seed from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel -from diffusers.models.cross_attention import LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor +from diffusers.models.cross_attention import LoRACrossAttnProcessor from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version from diffusers.utils.import_utils import is_xformers_available @@ -546,17 +546,35 @@ def main(args): else: raise ValueError("xformers is not available. Make sure it is installed correctly") - num_lora_layers = unet.num_attn_processors - - attention_head_dims: Union[List[int], int] = unet.config.attention_head_dim - - query_dim = unet.config.block_out_channels - cross_attention_dim = unet.config.cross_attention_dim - - if args.enable_xformers_memory_efficient_attention: - lora_attention_layers = [LoRAXFormersCrossAttnProcessor(query_dim, query_dim, cross_attention_dim, rank=args.lora_rank) for _ in range(num_lora_layers)] - else: - lora_attention_layers = [LoRACrossAttnProcessor(query_dim, query_dim, cross_attention_dim, rank=args.lora_rank) for _ in range(num_lora_layers)] + # now we will add new LoRA weights to the attention layers + # It's important to realize here how many attention weights will be added and of which sizes + # The sizes of the attention layers consist only of two different variables: + # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. + # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. + + # Let's first see how many attention processors we will have to set. 
+ # For Stable Diffusion, it should be equal to: + # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 + # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 + # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18 + # => 32 layers + + # Set correct lora layers + lora_attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + + unet.set_attn_processor(lora_attn_procs) if args.scale_lr: args.learning_rate = ( @@ -576,9 +594,7 @@ def main(args): else: optimizer_class = torch.optim.AdamW - params_to_optimize = ( - itertools.chain([layer.parameters() for layer in lora_attention_layers]) - ) + params_to_optimize = itertools.chain(*[v.parameters() for v in unet.attn_processors.values()]) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, @@ -758,6 +774,7 @@ def main(args): else unet.parameters() ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() lr_scheduler.step() optimizer.zero_grad() diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 70824b7d9b4f..3faff598ddd1 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -261,13 +261,13 @@ def forward(self, hidden_states): class LoRACrossAttnProcessor(nn.Module): - def __init__(self, query_dim, inner_dim, cross_attention_dim, rank=4): + def __init__(self, hidden_size, cross_attention_dim=None, rank=4): super().__init__() - self.to_q_lora = LoRALinearLayer(query_dim, inner_dim) - self.to_k_lora = LoRALinearLayer(query_dim, inner_dim) - self.to_v_lora = LoRALinearLayer(query_dim, inner_dim) - self.to_out_lora = LoRALinearLayer(query_dim, inner_dim) + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size) def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): batch_size, sequence_length, _ = hidden_states.shape @@ -367,13 +367,13 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No class LoRAXFormersCrossAttnProcessor(nn.Module): - def __init__(self, query_dim, inner_dim, cross_attention_dim, rank=4): + def __init__(self, hidden_size, cross_attention_dim, rank=4): super().__init__() - self.to_q_lora = LoRALinearLayer(query_dim, inner_dim) - self.to_k_lora = LoRALinearLayer(query_dim, inner_dim) - self.to_v_lora = LoRALinearLayer(query_dim, inner_dim) - self.to_out_lora = LoRALinearLayer(query_dim, inner_dim) + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or 
hidden_size, hidden_size) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size) def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): batch_size, sequence_length, _ = hidden_states.shape diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 63e656bb59be..70f94664ca37 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -267,39 +267,42 @@ def __init__( self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) @property - def num_attn_processors(self): + def attn_processors(self) -> Dict[str, AttnProcessor]: # set recursively - count = 0 + processors = {} - def fn_recursive_count_processor(module: torch.nn.Module, count: int): + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]): if hasattr(module, "set_processor"): - count += 1 + processors[name] = module.processor - for child in module.children(): - count = fn_recursive_count_processor(child) + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - return count + return processors - for module in self.children(): - count += fn_recursive_count_processor(module) + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) - return count + return processors - def set_attn_processor(self, processor: Union[AttnProcessor, List[AttnProcessor]]): - count = self.num_attn_layers + def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]): + count = len(self.attn_processors.keys()) - if isinstance(processor, list) and len(processor) != count: - raise ValueError(f"A list of processors was passed, but the number of processors {len(processor)} does not match the number of attention layers: {count}. Please make sure to pass {count} processor classes.") + if isinstance(processor, dict) and len(processor) != count: + raise ValueError(f"A dict of processors was passed, but the number of processors {len(processor)} does not match the number of attention layers: {count}. 
Please make sure to pass {count} processor classes.") - def fn_recursive_attn_processor(module: torch.nn.Module): + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): - module.set_processor(processor) + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(name)) - for child in module.children(): - fn_recursive_attn_processor(child) + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - for module in self.children(): - fn_recursive_attn_processor(module) + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) def set_attention_slice(self, slice_size): r""" From b8e9ce40d9462b4caf39448d3589532edb87f46a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 3 Jan 2023 15:33:53 +0000 Subject: [PATCH 06/26] up --- examples/lora/train_lora.py | 112 +++++++++++++++++++++++++----------- 1 file changed, 77 insertions(+), 35 deletions(-) diff --git a/examples/lora/train_lora.py b/examples/lora/train_lora.py index 65a2af0726aa..f60c28520bbd 100644 --- a/examples/lora/train_lora.py +++ b/examples/lora/train_lora.py @@ -1,6 +1,7 @@ import argparse import hashlib import itertools +import logging import math import os import warnings @@ -12,12 +13,15 @@ import torch.utils.checkpoint from torch.utils.data import Dataset +import datasets +import diffusers +import transformers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel -from diffusers.models.cross_attention import LoRACrossAttnProcessor from diffusers.optimization import get_scheduler +from diffusers.models.cross_attention import LoRACrossAttnProcessor from diffusers.utils import check_min_version from diffusers.utils.import_utils import is_xformers_available from huggingface_hub import HfFolder, Repository, whoami @@ -45,12 +49,12 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st from transformers import CLIPTextModel return CLIPTextModel -# elif model_class == "RobertaSeriesModelWithTransformation": -# from diffusers.pipelines.lora.modeling_lora import RobertaSeriesModelWithTransformation -# -# return RobertaSeriesModelWithTransformation -# else: -# raise ValueError(f"{model_class} is not supported.") + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") def parse_args(input_args=None): @@ -184,7 +188,7 @@ def parse_args(input_args=None): parser.add_argument( "--learning_rate", type=float, - default=5e-6, + default=5e-4, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( @@ -237,6 +241,23 @@ def parse_args(input_args=None): " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." ), ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) parser.add_argument( "--mixed_precision", type=str, @@ -423,7 +444,7 @@ def main(args): accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, - log_with="tensorboard", + log_with=args.report_to, logging_dir=logging_dir, ) @@ -436,9 +457,27 @@ def main(args): "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." ) + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed) + # Generate class images if prior preservation is enabled. if args.with_prior_preservation: class_images_dir = Path(args.class_data_dir) if not class_images_dir.exists(): @@ -503,11 +542,7 @@ def main(args): # Load the tokenizer if args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer_name, - revision=args.revision, - use_fast=False, - ) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) elif args.pretrained_model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( args.pretrained_model_name_or_path, @@ -519,21 +554,14 @@ def main(args): # import correct text encoder class text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) - # Load models and create wrapper for stable diffusion + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder = text_encoder_cls.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder", - revision=args.revision, - ) - vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="vae", - revision=args.revision, + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="unet", - revision=args.revision, + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) vae.requires_grad_(False) @@ -576,6 +604,21 @@ def main(args): unet.set_attn_processor(lora_attn_procs) + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if 
args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes @@ -594,6 +637,7 @@ def main(args): else: optimizer_class = torch.optim.AdamW + # Optimizer creation params_to_optimize = itertools.chain(*[v.parameters() for v in unet.attn_processors.values()]) optimizer = optimizer_class( params_to_optimize, @@ -603,8 +647,7 @@ def main(args): eps=args.adam_epsilon, ) - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - + # Dataset and DataLoaders creation: train_dataset = LoRADataset( instance_data_root=args.instance_data_dir, instance_prompt=args.instance_prompt, @@ -639,6 +682,7 @@ def main(args): power=args.lr_power, ) + # Prepare everything with our `accelerator`. if args.train_text_encoder: unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, text_encoder, optimizer, train_dataloader, lr_scheduler @@ -647,17 +691,16 @@ def main(args): unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, optimizer, train_dataloader, lr_scheduler ) - accelerator.register_for_checkpointing(lr_scheduler) + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - # Move text_encode and vae to gpu. - # For mixed precision training we cast the text_encoder and vae weights to half-precision - # as these models are only used for inference, keeping weights in full precision is not required. + # Move vae and text_encoder to device and cast to weight_dtype vae.to(accelerator.device, dtype=weight_dtype) if not args.train_text_encoder: text_encoder.to(accelerator.device, dtype=weight_dtype) @@ -688,6 +731,7 @@ def main(args): global_step = 0 first_epoch = 0 + # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) @@ -774,7 +818,6 @@ def main(args): else unet.parameters() ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - optimizer.step() lr_scheduler.step() optimizer.zero_grad() @@ -797,9 +840,8 @@ def main(args): if global_step >= args.max_train_steps: break - accelerator.wait_for_everyone() - # Create the pipeline using using the trained modules and save it. 
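# (Annotation, not part of the patch.) The hunks above route only the LoRA processor
# weights into AdamW while the UNet itself stays frozen. For a sense of scale, a small
# sketch of what a single processor contributes, assuming the LoRACrossAttnProcessor /
# LoRALinearLayer definitions introduced in this patch series and the Stable Diffusion
# v1 sizes hidden_size=320, cross_attention_dim=768 of the first down block:
from diffusers.models.cross_attention import LoRACrossAttnProcessor

proc = LoRACrossAttnProcessor(hidden_size=320, cross_attention_dim=768)
# Each of to_q/to_k/to_v/to_out holds a rank-4 pair of bias-free Linear layers
# (in_features -> 4 and 4 -> out_features):
#   to_q_lora, to_out_lora: 320*4 + 4*320 = 2,560 parameters each
#   to_k_lora, to_v_lora:   768*4 + 4*320 = 4,352 parameters each
print(sum(p.numel() for p in proc.parameters()))  # 13,824 trainable parameters for this one attention layer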
+ accelerator.wait_for_everyone() if accelerator.is_main_process: pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, From f7719e0d55075d480544096781edd410a912328f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 3 Jan 2023 18:20:10 +0000 Subject: [PATCH 07/26] correct --- examples/lora/train_lora.py | 30 +++++++++++++++++++++++ src/diffusers/models/unet_2d_condition.py | 5 ++++ 2 files changed, 35 insertions(+) diff --git a/examples/lora/train_lora.py b/examples/lora/train_lora.py index f60c28520bbd..54d232a5b0aa 100644 --- a/examples/lora/train_lora.py +++ b/examples/lora/train_lora.py @@ -106,6 +106,12 @@ def parse_args(input_args=None): default=None, help="The prompt to specify images in the same class as provided instance images.", ) + parser.add_argument( + "--save_sample_prompt", + type=str, + default=None, + help="A prompt that is sampled during training." + ) parser.add_argument( "--with_prior_preservation", default=False, @@ -840,6 +846,30 @@ def main(args): if global_step >= args.max_train_steps: break + if args.save_sample_prompt is not None: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + revision=args.revision, + ) + pipeline.save_pretrained(args.output_dir) + pipeline = pipeline.to(accelerator.device) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + pipeline.set_progress_bar_config(disable=True) + sample_dir = "/home/patrick_huggingface_co/lora-tryout/samples" + os.makedirs(sample_dir, exist_ok=True) + with torch.autocast("cuda"), torch.inference_mode(): + for i in tqdm(range(args.n_save_sample), desc="Generating samples"): + images = pipeline( + args.save_sample_prompt, + num_inference_steps=30, + generator=generator + ).images + images[0].save(os.path.join(sample_dir, f"{i}.png")) + del pipeline + torch.cuda.empty_cache() + # Create the pipeline using using the trained modules and save it. 
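# (Annotation, not part of the patch.) The sampling block above writes into a
# hard-coded home directory. A minimal, more portable sketch would derive the sample
# folder from the output directory the script already parses; the helper below is
# hypothetical and only mirrors that idea:
import os

def make_sample_dir(output_dir: str) -> str:
    sample_dir = os.path.join(output_dir, "samples")
    os.makedirs(sample_dir, exist_ok=True)
    return sample_dir

# e.g. sample_dir = make_sample_dir(args.output_dir)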
accelerator.wait_for_everyone() if accelerator.is_main_process: diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 70f94664ca37..9d4e5767dbdb 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -123,6 +123,7 @@ def __init__( num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", + attn_processor_cls=None, ): super().__init__() @@ -266,6 +267,10 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + if attn_processor_cls is not None: + attn_processors_keys = self.attn_processors.keys() + self.set_attn_processor({k: attn_processor_cls() for k in attn_processors_keys}) + @property def attn_processors(self) -> Dict[str, AttnProcessor]: # set recursively From b69f2762ed9cb83c36a4278c96c015f3372c1087 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 13 Jan 2023 15:26:23 +0000 Subject: [PATCH 08/26] improve --- examples/lora/train_lora.py | 41 +++++++++++++++-------- src/diffusers/models/cross_attention.py | 12 ++++--- src/diffusers/models/unet_2d_condition.py | 5 ++- 3 files changed, 38 insertions(+), 20 deletions(-) diff --git a/examples/lora/train_lora.py b/examples/lora/train_lora.py index 54d232a5b0aa..b9ed0fa9ac51 100644 --- a/examples/lora/train_lora.py +++ b/examples/lora/train_lora.py @@ -29,6 +29,9 @@ from torchvision import transforms from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig +import wandb + +wandb.login() # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -36,6 +39,10 @@ logger = get_logger(__name__) +run = wandb.init(project="stable_diffusion_lora") + +generated_table = wandb.Table(columns=["gen_num", "prompt", "generated_images"]) + def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( @@ -107,10 +114,7 @@ def parse_args(input_args=None): help="The prompt to specify images in the same class as provided instance images.", ) parser.add_argument( - "--save_sample_prompt", - type=str, - default=None, - help="A prompt that is sampled during training." + "--save_sample_prompt", type=str, default=None, help="A prompt that is sampled during training." 
) parser.add_argument( "--with_prior_preservation", @@ -606,7 +610,9 @@ def main(args): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] - lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + lora_attn_procs[name] = LoRACrossAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + ) unet.set_attn_processor(lora_attn_procs) @@ -846,7 +852,8 @@ def main(args): if global_step >= args.max_train_steps: break - if args.save_sample_prompt is not None: + if args.save_sample_prompt is not None and epoch % 10 == 0: + print("Running inference...") pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), @@ -859,17 +866,23 @@ def main(args): pipeline.set_progress_bar_config(disable=True) sample_dir = "/home/patrick_huggingface_co/lora-tryout/samples" os.makedirs(sample_dir, exist_ok=True) - with torch.autocast("cuda"), torch.inference_mode(): - for i in tqdm(range(args.n_save_sample), desc="Generating samples"): - images = pipeline( - args.save_sample_prompt, - num_inference_steps=30, - generator=generator - ).images - images[0].save(os.path.join(sample_dir, f"{i}.png")) + + for i in tqdm(range(5), desc="Generating samples"): + prompt = args.save_sample_prompt + images = pipeline(prompt, num_inference_steps=30, generator=generator).images + image = images[0] + + image.save(os.path.join(sample_dir, f"{i}.png")) + + global_step = epoch * len(train_dataloader) + i + generated_table.add_data(global_step, prompt, wandb.Image(image)) + run.log({"generated_image": wandb.Image(image)}) + del pipeline torch.cuda.empty_cache() + run.log({"generated_table": generated_table}) + # Create the pipeline using using the trained modules and save it. 
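# (Annotation, not part of the patch.) Patch 08 calls wandb.login()/wandb.init() at
# import time, so every process in a multi-GPU run would open its own W&B run. One
# possible alternative, sketched here rather than taken from the patch, is to create
# the run only on the main process once the Accelerator exists (the project name is
# reused from the patch; the logged value is a placeholder):
import wandb
from accelerate import Accelerator

accelerator = Accelerator()
run = wandb.init(project="stable_diffusion_lora") if accelerator.is_main_process else None

if run is not None:
    run.log({"placeholder_metric": 0.0})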
accelerator.wait_for_everyone() if accelerator.is_main_process: diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 3faff598ddd1..9747ff09822d 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -242,9 +242,7 @@ def __init__(self, in_features, out_features, rank=4): super().__init__() if rank > min(in_features, out_features): - raise ValueError( - f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}" - ) + raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") self.lora_down = nn.Linear(in_features, rank, bias=False) self.lora_up = nn.Linear(rank, out_features, bias=False) @@ -269,7 +267,9 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4): self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size) self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size) - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + def __call__( + self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 + ): batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) @@ -375,7 +375,9 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4): self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size) self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size) - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + def __call__( + self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 + ): batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 9d4e5767dbdb..0d9b71f377f7 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -294,7 +294,10 @@ def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProce count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: - raise ValueError(f"A dict of processors was passed, but the number of processors {len(processor)} does not match the number of attention layers: {count}. Please make sure to pass {count} processor classes.") + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): From 5d6ee568516715833bd9bc801ec12157f9828a93 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 13 Jan 2023 17:52:56 +0000 Subject: [PATCH 09/26] finish loaders and inference --- examples/lora/train_lora.py | 89 ++++++++++--------- src/diffusers/pipelines/loaders.py | 38 ++++++++ .../pipeline_stable_diffusion.py | 3 +- 3 files changed, 87 insertions(+), 43 deletions(-) create mode 100644 src/diffusers/pipelines/loaders.py diff --git a/examples/lora/train_lora.py b/examples/lora/train_lora.py index b9ed0fa9ac51..1ba288a17aa3 100644 --- a/examples/lora/train_lora.py +++ b/examples/lora/train_lora.py @@ -1,12 +1,11 @@ import argparse import hashlib -import itertools import logging import math import os import warnings from pathlib import Path -from typing import Optional +from typing import Optional, Dict import torch import torch.nn.functional as F @@ -34,6 +33,36 @@ wandb.login() +class LoraLayers(torch.nn.Module): + def __init__(self, state_dict: Dict[str, torch.Tensor]): + super().__init__() + self.layers = torch.nn.ModuleList(state_dict.values()) + self.mapping = {k: v for k, v in enumerate(state_dict.keys())} + self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} + + # we add a hook to state_dict() and load_state_dict() so that the + # naming fits with `unet.attn_processors` + def map_to(module, state_dict, *args, **kwargs): + new_state_dict = {} + for key, value in state_dict.items(): + num = int(key.split(".")[1]) # 0 is always "layers" + new_key = key.replace(f"layers.{num}", module.mapping[num]) + new_state_dict[new_key] = value + + return new_state_dict + + def map_from(module, state_dict, *args, **kwargs): + all_keys = list(state_dict.keys()) + for key in all_keys: + replace_key = ".".join(key.split(".")[:-3]) + new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") + state_dict[new_key] = state_dict[key] + del state_dict[key] + + self._register_state_dict_hook(map_to) + self._register_load_state_dict_pre_hook(map_from, with_module=True) + + # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.10.0.dev0") @@ -151,7 +180,6 @@ def parse_args(input_args=None): parser.add_argument( "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" ) - parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") parser.add_argument( "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) @@ -461,12 +489,6 @@ def main(args): # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. - if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: - raise ValueError( - "Gradient accumulation is not supported when training the text encoder in distributed training. " - "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." - ) - # Make one log on every process with the configuration for debugging. 
logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -615,17 +637,18 @@ def main(args): ) unet.set_attn_processor(lora_attn_procs) + lora_layers = LoraLayers(unet.attn_processors) + + state_dict = lora_layers.state_dict() + lora_layers.load_state_dict(state_dict) + + accelerator.register_for_checkpointing(lora_layers) if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - if args.train_text_encoder: - text_encoder.gradient_checkpointing_enable() - # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: @@ -650,9 +673,8 @@ def main(args): optimizer_class = torch.optim.AdamW # Optimizer creation - params_to_optimize = itertools.chain(*[v.parameters() for v in unet.attn_processors.values()]) optimizer = optimizer_class( - params_to_optimize, + lora_layers.parameters(), lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, @@ -695,14 +717,9 @@ def main(args): ) # Prepare everything with our `accelerator`. - if args.train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler - ) - else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, optimizer, train_dataloader, lr_scheduler - ) + lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + lora_layers, optimizer, train_dataloader, lr_scheduler + ) # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. @@ -712,10 +729,10 @@ def main(args): elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - # Move vae and text_encoder to device and cast to weight_dtype + # Move unet, vae and text_encoder to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) - if not args.train_text_encoder: - text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -767,8 +784,6 @@ def main(args): for epoch in range(first_epoch, args.num_train_epochs): unet.train() - if args.train_text_encoder: - text_encoder.train() for step, batch in enumerate(train_dataloader): # Skip steps until we reach the resumed step if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: @@ -824,11 +839,7 @@ def main(args): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = ( - itertools.chain(unet.parameters(), text_encoder.parameters()) - if args.train_text_encoder - else unet.parameters() - ) + params_to_clip = lora_layers.parameters() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() @@ -883,16 +894,10 @@ def main(args): run.log({"generated_table": generated_table}) - # Create the pipeline using using the trained modules and save it. 
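# (Annotation, not part of the patch.) The LoraLayers wrapper above exists purely so
# that checkpoint keys use the attention-processor names instead of ModuleList indices.
# Sketch of the renaming done by its state_dict hooks, assuming the `unet` and the
# LoraLayers class defined in this patch; the concrete layer name only illustrates the
# key format and is not output from a real run:
lora_layers = LoraLayers(unet.attn_processors)
state_dict = lora_layers.state_dict()
# Without the map_to hook a key would look like
#   "layers.0.to_q_lora.lora_down.weight"
# after the hook, "layers.0" is replaced by the processor's name in unet.attn_processors, e.g.
#   "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q_lora.lora_down.weight"
# map_from reverses the mapping, so loading the saved file round-trips cleanly:
lora_layers.load_state_dict(state_dict)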
+ # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), - text_encoder=accelerator.unwrap_model(text_encoder), - revision=args.revision, - ) - pipeline.save_pretrained(args.output_dir) + torch.save(lora_layers.state_dict(), os.path.join(args.output_dir, "lora_layers.bin")) if args.push_to_hub: repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) diff --git a/src/diffusers/pipelines/loaders.py b/src/diffusers/pipelines/loaders.py new file mode 100644 index 000000000000..e61c2681c993 --- /dev/null +++ b/src/diffusers/pipelines/loaders.py @@ -0,0 +1,38 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ..models.cross_attention import LoRACrossAttnProcessor +from collections import defaultdict +import torch + + +class LoraUNetLoaderMixin: + def load_lora(self, pretrained_model_name_or_path): + state_dict = torch.load(pretrained_model_name_or_path, map_location="cpu") + lora_grouped_dict = defaultdict(dict) + for key, value in state_dict.items(): + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + lora_grouped_dict[attn_processor_key][sub_key] = value + + attn_processors = {} + for key, value_dict in lora_grouped_dict.items(): + rank = value_dict["to_k_lora.lora_down.weight"].shape[0] + cross_attention_dim = value_dict["to_k_lora.lora_down.weight"].shape[1] + hidden_size = value_dict["to_k_lora.lora_up.weight"].shape[0] + + attn_processors[key] = LoRACrossAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank + ) + attn_processors[key].load_state_dict(value_dict) + + self.unet.set_attn_processor(attn_processors) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index edfc8eaf7a52..4e3484fba36c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -32,6 +32,7 @@ ) from ...utils import deprecate, is_accelerate_available, logging, replace_example_docstring from ..pipeline_utils import DiffusionPipeline +from ..loaders import LoraUNetLoaderMixin from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -53,7 +54,7 @@ """ -class StableDiffusionPipeline(DiffusionPipeline): +class StableDiffusionPipeline(DiffusionPipeline, LoraUNetLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion. 
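# (Annotation, not part of the patch.) End-to-end, the training script now saves only
# `lora_layers.bin`, and StableDiffusionPipeline gains `load_lora()` through the
# LoraUNetLoaderMixin added above. A hedged usage sketch of the intended inference
# flow; the model id, file path and prompt are placeholders and a CUDA device is assumed:
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora("lora-output/lora_layers.bin")  # the file written by train_lora.py
pipe = pipe.to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe("a photo of sks dog in a bucket", num_inference_steps=30, generator=generator).images[0]
image.save("lora_sample.png")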
From bc15289f2ef56caa6cb062d80215ce3f63ef31a0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 13 Jan 2023 18:44:53 +0000 Subject: [PATCH 10/26] up --- src/diffusers/models/modeling_utils.py | 94 ++------------------------ src/diffusers/pipelines/loaders.py | 53 ++++++++++++++- 2 files changed, 56 insertions(+), 91 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 91c44973b34f..152c8f98b16f 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -435,7 +435,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model_file = None if is_safetensors_available(): try: - model_file = cls._get_model_file( + model_file = _get_model_file( pretrained_model_name_or_path, weights_name=SAFETENSORS_WEIGHTS_NAME, cache_dir=cache_dir, @@ -451,7 +451,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P except: pass if model_file is None: - model_file = cls._get_model_file( + model_file = _get_model_file( pretrained_model_name_or_path, weights_name=WEIGHTS_NAME, cache_dir=cache_dir, @@ -556,92 +556,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P return model - @classmethod - def _get_model_file( - cls, - pretrained_model_name_or_path, - *, - weights_name, - subfolder, - cache_dir, - force_download, - proxies, - resume_download, - local_files_only, - use_auth_token, - user_agent, - revision, - ): - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): - # Load from a PyTorch checkpoint - model_file = os.path.join(pretrained_model_name_or_path, weights_name) - elif subfolder is not None and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, weights_name) - ): - model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name) - else: - raise EnvironmentError( - f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." - ) - return model_file - else: - try: - # Load from URL or cache if already cached - model_file = hf_hub_download( - pretrained_model_name_or_path, - filename=weights_name, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - user_agent=user_agent, - subfolder=subfolder, - revision=revision, - ) - return model_file - - except RepositoryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " - "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " - "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " - "login`." - ) - except RevisionNotFoundError: - raise EnvironmentError( - f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " - "this model name. Check the model page at " - f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." - ) - except EntryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}." 
- ) - except HTTPError as err: - raise EnvironmentError( - "There was a specific connection error when trying to load" - f" {pretrained_model_name_or_path}:\n{err}" - ) - except ValueError: - raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" - f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" - f" directory containing a file named {weights_name} or" - " \nCheckout your internet connection or see how to run the library in" - " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." - ) - except EnvironmentError: - raise EnvironmentError( - f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " - "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " - f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " - f"containing a file named {weights_name}" - ) - @classmethod def _load_pretrained_model( cls, @@ -805,7 +719,9 @@ def _get_model_file( revision, ): pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(pretrained_model_name_or_path): + return pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): # Load from a PyTorch checkpoint model_file = os.path.join(pretrained_model_name_or_path, weights_name) diff --git a/src/diffusers/pipelines/loaders.py b/src/diffusers/pipelines/loaders.py index e61c2681c993..d39914fb00e3 100644 --- a/src/diffusers/pipelines/loaders.py +++ b/src/diffusers/pipelines/loaders.py @@ -12,13 +12,59 @@ # See the License for the specific language governing permissions and # limitations under the License. from ..models.cross_attention import LoRACrossAttnProcessor +from ..models.modeling_utils import _get_model_file from collections import defaultdict import torch +from ..utils import ( + DIFFUSERS_CACHE, + HF_HUB_OFFLINE, + logging, +) + + +logger = logging.get_logger(__name__) + + +LORA_WEIGHT_NAME = "pytorch_lora.bin" class LoraUNetLoaderMixin: - def load_lora(self, pretrained_model_name_or_path): - state_dict = torch.load(pretrained_model_name_or_path, map_location="cpu") + def load_lora(self, pretrained_model_name_or_path, **kwargs): + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + + user_agent = { + "file_type": "lora", + "framework": "pytorch", + } + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
+ ) + + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + + state_dict = torch.load(model_file, map_location="cpu") lora_grouped_dict = defaultdict(dict) for key, value in state_dict.items(): attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) @@ -35,4 +81,7 @@ def load_lora(self, pretrained_model_name_or_path): ) attn_processors[key].load_state_dict(value_dict) + if torch_dtype is not None: + attn_processors[key].to(torch_dtype) + self.unet.set_attn_processor(attn_processors) From d334d5a6d8f4177b246d347fb77224a9613f189c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 16 Jan 2023 15:51:08 +0000 Subject: [PATCH 11/26] up --- tests/models/test_models_unet_2d.py | 545 +--------------- tests/models/test_models_unet_2d_condition.py | 581 ++++++++++++++++++ 2 files changed, 582 insertions(+), 544 deletions(-) create mode 100644 tests/models/test_models_unet_2d_condition.py diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index 91192f17fb00..803a9e88d0cb 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -20,19 +20,14 @@ import torch -from diffusers import UNet2DConditionModel, UNet2DModel +from diffusers import UNet2DModel from diffusers.utils import ( floats_tensor, - load_hf_numpy, logging, - require_torch_gpu, slow, torch_all_close, torch_device, ) -from diffusers.utils.import_utils import is_xformers_available -from parameterized import parameterized - from ..test_modeling_common import ModelTesterMixin @@ -218,237 +213,6 @@ def test_output_pretrained(self): self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3)) -class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase): - model_class = UNet2DConditionModel - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - sizes = (32, 32) - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device) - - return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} - - @property - def input_shape(self): - return (4, 32, 32) - - @property - def output_shape(self): - return (4, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "block_out_channels": (32, 64), - "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), - "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), - "cross_attention_dim": 32, - "attention_head_dim": 8, - "out_channels": 4, - "in_channels": 4, - "layers_per_block": 2, - "sample_size": 32, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_xformers_enable_works(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - - model.enable_xformers_memory_efficient_attention() - - assert ( - model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers 
- ), "xformers is not enabled" - - @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS") - def test_gradient_checkpointing(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - - assert not model.is_gradient_checkpointing and model.training - - out = model(**inputs_dict).sample - # run the backwards pass on the model. For backwards pass, for simplicity purpose, - # we won't calculate the loss and rather backprop on out.sum() - model.zero_grad() - - labels = torch.randn_like(out) - loss = (out - labels).mean() - loss.backward() - - # re-instantiate the model now enabling gradient checkpointing - model_2 = self.model_class(**init_dict) - # clone model - model_2.load_state_dict(model.state_dict()) - model_2.to(torch_device) - model_2.enable_gradient_checkpointing() - - assert model_2.is_gradient_checkpointing and model_2.training - - out_2 = model_2(**inputs_dict).sample - # run the backwards pass on the model. For backwards pass, for simplicity purpose, - # we won't calculate the loss and rather backprop on out.sum() - model_2.zero_grad() - loss_2 = (out_2 - labels).mean() - loss_2.backward() - - # compare the output and parameters gradients - self.assertTrue((loss - loss_2).abs() < 1e-5) - named_params = dict(model.named_parameters()) - named_params_2 = dict(model_2.named_parameters()) - for name, param in named_params.items(): - self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5)) - - def test_model_with_attention_head_dim_tuple(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.sample - - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_model_with_use_linear_projection(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["use_linear_projection"] = True - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.sample - - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_model_attention_slicing(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - model.set_attention_slice("auto") - with torch.no_grad(): - output = model(**inputs_dict) - assert output is not None - - model.set_attention_slice("max") - with torch.no_grad(): - output = model(**inputs_dict) - assert output is not None - - model.set_attention_slice(2) - with torch.no_grad(): - output = model(**inputs_dict) - assert output is not None - - def test_model_slicable_head_dim(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - model = self.model_class(**init_dict) - - def 
check_slicable_dim_attr(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - assert isinstance(module.sliceable_head_dim, int) - - for child in module.children(): - check_slicable_dim_attr(child) - - # retrieve number of attention layers - for module in model.children(): - check_slicable_dim_attr(module) - - def test_special_attn_proc(self): - class AttnEasyProc(torch.nn.Module): - def __init__(self, num): - super().__init__() - self.weight = torch.nn.Parameter(torch.tensor(num)) - self.is_run = False - self.number = 0 - self.counter = 0 - - def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - query = attn.to_q(hidden_states) - - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - hidden_states += self.weight - - self.is_run = True - self.counter += 1 - self.number = number - - return hidden_states - - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - model = self.model_class(**init_dict) - model.to(torch_device) - - processor = AttnEasyProc(5.0) - - model.set_attn_processor(processor) - model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample - - assert processor.counter == 12 - assert processor.is_run - assert processor.number == 123 - - class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNet2DModel @@ -564,310 +328,3 @@ def test_output_pretrained_ve_large(self): def test_forward_with_norm_groups(self): # not required for this model pass - - -@slow -class UNet2DConditionModelIntegrationTests(unittest.TestCase): - def get_file_format(self, seed, shape): - return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy" - - def tearDown(self): - # clean up the VRAM after each test - super().tearDown() - gc.collect() - torch.cuda.empty_cache() - - def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False): - dtype = torch.float16 if fp16 else torch.float32 - image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype) - return image - - def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): - revision = "fp16" if fp16 else None - torch_dtype = torch.float16 if fp16 else torch.float32 - - model = UNet2DConditionModel.from_pretrained( - model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision - ) - model.to(torch_device).eval() - - return model - - def test_set_attention_slice_auto(self): - torch.cuda.empty_cache() - torch.cuda.reset_max_memory_allocated() - torch.cuda.reset_peak_memory_stats() - - unet = self.get_unet_model() - unet.set_attention_slice("auto") - - latents = self.get_latents(33) - encoder_hidden_states = self.get_encoder_hidden_states(33) - timestep = 1 
- - with torch.no_grad(): - _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample - - mem_bytes = torch.cuda.max_memory_allocated() - - assert mem_bytes < 5 * 10**9 - - def test_set_attention_slice_max(self): - torch.cuda.empty_cache() - torch.cuda.reset_max_memory_allocated() - torch.cuda.reset_peak_memory_stats() - - unet = self.get_unet_model() - unet.set_attention_slice("max") - - latents = self.get_latents(33) - encoder_hidden_states = self.get_encoder_hidden_states(33) - timestep = 1 - - with torch.no_grad(): - _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample - - mem_bytes = torch.cuda.max_memory_allocated() - - assert mem_bytes < 5 * 10**9 - - def test_set_attention_slice_int(self): - torch.cuda.empty_cache() - torch.cuda.reset_max_memory_allocated() - torch.cuda.reset_peak_memory_stats() - - unet = self.get_unet_model() - unet.set_attention_slice(2) - - latents = self.get_latents(33) - encoder_hidden_states = self.get_encoder_hidden_states(33) - timestep = 1 - - with torch.no_grad(): - _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample - - mem_bytes = torch.cuda.max_memory_allocated() - - assert mem_bytes < 5 * 10**9 - - def test_set_attention_slice_list(self): - torch.cuda.empty_cache() - torch.cuda.reset_max_memory_allocated() - torch.cuda.reset_peak_memory_stats() - - # there are 32 slicable layers - slice_list = 16 * [2, 3] - unet = self.get_unet_model() - unet.set_attention_slice(slice_list) - - latents = self.get_latents(33) - encoder_hidden_states = self.get_encoder_hidden_states(33) - timestep = 1 - - with torch.no_grad(): - _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample - - mem_bytes = torch.cuda.max_memory_allocated() - - assert mem_bytes < 5 * 10**9 - - def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): - dtype = torch.float16 if fp16 else torch.float32 - hidden_states = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype) - return hidden_states - - @parameterized.expand( - [ - # fmt: off - [33, 4, [-0.4424, 0.1510, -0.1937, 0.2118, 0.3746, -0.3957, 0.0160, -0.0435]], - [47, 0.55, [-0.1508, 0.0379, -0.3075, 0.2540, 0.3633, -0.0821, 0.1719, -0.0207]], - [21, 0.89, [-0.6479, 0.6364, -0.3464, 0.8697, 0.4443, -0.6289, -0.0091, 0.1778]], - [9, 1000, [0.8888, -0.5659, 0.5834, -0.7469, 1.1912, -0.3923, 1.1241, -0.4424]], - # fmt: on - ] - ) - @require_torch_gpu - def test_compvis_sd_v1_4(self, seed, timestep, expected_slice): - model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4") - latents = self.get_latents(seed) - encoder_hidden_states = self.get_encoder_hidden_states(seed) - - timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) - - with torch.no_grad(): - sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample - - assert sample.shape == latents.shape - - output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() - expected_output_slice = torch.tensor(expected_slice) - - assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) - - @parameterized.expand( - [ - # fmt: off - [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]], - [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]], - [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]], - [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 
1.1328, -0.2443, -0.0325, -1.0078]], - # fmt: on - ] - ) - @require_torch_gpu - def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice): - model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True) - latents = self.get_latents(seed, fp16=True) - encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) - - timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) - - with torch.no_grad(): - sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample - - assert sample.shape == latents.shape - - output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() - expected_output_slice = torch.tensor(expected_slice) - - assert torch_all_close(output_slice, expected_output_slice, atol=5e-3) - - @parameterized.expand( - [ - # fmt: off - [33, 4, [-0.4430, 0.1570, -0.1867, 0.2376, 0.3205, -0.3681, 0.0525, -0.0722]], - [47, 0.55, [-0.1415, 0.0129, -0.3136, 0.2257, 0.3430, -0.0536, 0.2114, -0.0436]], - [21, 0.89, [-0.7091, 0.6664, -0.3643, 0.9032, 0.4499, -0.6541, 0.0139, 0.1750]], - [9, 1000, [0.8878, -0.5659, 0.5844, -0.7442, 1.1883, -0.3927, 1.1192, -0.4423]], - # fmt: on - ] - ) - @require_torch_gpu - def test_compvis_sd_v1_5(self, seed, timestep, expected_slice): - model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5") - latents = self.get_latents(seed) - encoder_hidden_states = self.get_encoder_hidden_states(seed) - - timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) - - with torch.no_grad(): - sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample - - assert sample.shape == latents.shape - - output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() - expected_output_slice = torch.tensor(expected_slice) - - assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) - - @parameterized.expand( - [ - # fmt: off - [83, 4, [-0.2695, -0.1669, 0.0073, -0.3181, -0.1187, -0.1676, -0.1395, -0.5972]], - [17, 0.55, [-0.1290, -0.2588, 0.0551, -0.0916, 0.3286, 0.0238, -0.3669, 0.0322]], - [8, 0.89, [-0.5283, 0.1198, 0.0870, -0.1141, 0.9189, -0.0150, 0.5474, 0.4319]], - [3, 1000, [-0.5601, 0.2411, -0.5435, 0.1268, 1.1338, -0.2427, -0.0280, -1.0020]], - # fmt: on - ] - ) - @require_torch_gpu - def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice): - model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5", fp16=True) - latents = self.get_latents(seed, fp16=True) - encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) - - timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) - - with torch.no_grad(): - sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample - - assert sample.shape == latents.shape - - output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() - expected_output_slice = torch.tensor(expected_slice) - - assert torch_all_close(output_slice, expected_output_slice, atol=5e-3) - - @parameterized.expand( - [ - # fmt: off - [33, 4, [-0.7639, 0.0106, -0.1615, -0.3487, -0.0423, -0.7972, 0.0085, -0.4858]], - [47, 0.55, [-0.6564, 0.0795, -1.9026, -0.6258, 1.8235, 1.2056, 1.2169, 0.9073]], - [21, 0.89, [0.0327, 0.4399, -0.6358, 0.3417, 0.4120, -0.5621, -0.0397, -1.0430]], - [9, 1000, [0.1600, 0.7303, -1.0556, -0.3515, -0.7440, -1.2037, -1.8149, -1.8931]], - # fmt: on - ] - ) - @require_torch_gpu - def test_compvis_sd_inpaint(self, seed, timestep, expected_slice): - model = 
self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting") - latents = self.get_latents(seed, shape=(4, 9, 64, 64)) - encoder_hidden_states = self.get_encoder_hidden_states(seed) - - timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) - - with torch.no_grad(): - sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample - - assert sample.shape == (4, 4, 64, 64) - - output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() - expected_output_slice = torch.tensor(expected_slice) - - assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) - - @parameterized.expand( - [ - # fmt: off - [83, 4, [-0.1047, -1.7227, 0.1067, 0.0164, -0.5698, -0.4172, -0.1388, 1.1387]], - [17, 0.55, [0.0975, -0.2856, -0.3508, -0.4600, 0.3376, 0.2930, -0.2747, -0.7026]], - [8, 0.89, [-0.0952, 0.0183, -0.5825, -0.1981, 0.1131, 0.4668, -0.0395, -0.3486]], - [3, 1000, [0.4790, 0.4949, -1.0732, -0.7158, 0.7959, -0.9478, 0.1105, -0.9741]], - # fmt: on - ] - ) - @require_torch_gpu - def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice): - model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting", fp16=True) - latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True) - encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) - - timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) - - with torch.no_grad(): - sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample - - assert sample.shape == (4, 4, 64, 64) - - output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() - expected_output_slice = torch.tensor(expected_slice) - - assert torch_all_close(output_slice, expected_output_slice, atol=5e-3) - - @parameterized.expand( - [ - # fmt: off - [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]], - [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]], - [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]], - [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]], - # fmt: on - ] - ) - @require_torch_gpu - def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice): - model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True) - latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) - encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True) - - timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) - - with torch.no_grad(): - sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample - - assert sample.shape == latents.shape - - output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() - expected_output_slice = torch.tensor(expected_slice) - - assert torch_all_close(output_slice, expected_output_slice, atol=5e-3) diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py new file mode 100644 index 000000000000..3cf3095c1477 --- /dev/null +++ b/tests/models/test_models_unet_2d_condition.py @@ -0,0 +1,581 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import torch + +from diffusers import UNet2DConditionModel +from diffusers.utils import ( + floats_tensor, + load_hf_numpy, + logging, + require_torch_gpu, + slow, + torch_all_close, + torch_device, +) +from diffusers.utils.import_utils import is_xformers_available +from parameterized import parameterized + +from ..test_modeling_common import ModelTesterMixin + + +logger = logging.get_logger(__name__) +torch.backends.cuda.matmul.allow_tf32 = False + + +class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet2DConditionModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 4 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device) + + return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + + @property + def input_shape(self): + return (4, 32, 32) + + @property + def output_shape(self): + return (4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64), + "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), + "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), + "cross_attention_dim": 32, + "attention_head_dim": 8, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 2, + "sample_size": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_enable_works(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + + model.enable_xformers_memory_efficient_attention() + + assert ( + model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers + ), "xformers is not enabled" + + def test_set_attention_processors(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + + @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS") + def test_gradient_checkpointing(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + + assert not model.is_gradient_checkpointing and model.training + + out = model(**inputs_dict).sample + # run the backwards pass on the model. 
For backwards pass, for simplicity purpose, + # we won't calculate the loss and rather backprop on out.sum() + model.zero_grad() + + labels = torch.randn_like(out) + loss = (out - labels).mean() + loss.backward() + + # re-instantiate the model now enabling gradient checkpointing + model_2 = self.model_class(**init_dict) + # clone model + model_2.load_state_dict(model.state_dict()) + model_2.to(torch_device) + model_2.enable_gradient_checkpointing() + + assert model_2.is_gradient_checkpointing and model_2.training + + out_2 = model_2(**inputs_dict).sample + # run the backwards pass on the model. For backwards pass, for simplicity purpose, + # we won't calculate the loss and rather backprop on out.sum() + model_2.zero_grad() + loss_2 = (out_2 - labels).mean() + loss_2.backward() + + # compare the output and parameters gradients + self.assertTrue((loss - loss_2).abs() < 1e-5) + named_params = dict(model.named_parameters()) + named_params_2 = dict(model_2.named_parameters()) + for name, param in named_params.items(): + self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5)) + + def test_model_with_attention_head_dim_tuple(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_model_with_use_linear_projection(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["use_linear_projection"] = True + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_model_attention_slicing(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + model.set_attention_slice("auto") + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + model.set_attention_slice("max") + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + model.set_attention_slice(2) + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + def test_model_slicable_head_dim(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + + def check_slicable_dim_attr(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + assert isinstance(module.sliceable_head_dim, int) + + for child in module.children(): + check_slicable_dim_attr(child) + + # retrieve number of attention layers + for module in model.children(): + check_slicable_dim_attr(module) + + def test_special_attn_proc(self): + class AttnEasyProc(torch.nn.Module): + def __init__(self, num): + super().__init__() + self.weight = torch.nn.Parameter(torch.tensor(num)) + self.is_run = False 
+ self.number = 0 + self.counter = 0 + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + query = attn.to_q(hidden_states) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states += self.weight + + self.is_run = True + self.counter += 1 + self.number = number + + return hidden_states + + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + + processor = AttnEasyProc(5.0) + + model.set_attn_processor(processor) + model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample + + assert processor.counter == 12 + assert processor.is_run + assert processor.number == 123 + + +@slow +class UNet2DConditionModelIntegrationTests(unittest.TestCase): + def get_file_format(self, seed, shape): + return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy" + + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False): + dtype = torch.float16 if fp16 else torch.float32 + image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype) + return image + + def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): + revision = "fp16" if fp16 else None + torch_dtype = torch.float16 if fp16 else torch.float32 + + model = UNet2DConditionModel.from_pretrained( + model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision + ) + model.to(torch_device).eval() + + return model + + def test_set_attention_slice_auto(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + unet = self.get_unet_model() + unet.set_attention_slice("auto") + + latents = self.get_latents(33) + encoder_hidden_states = self.get_encoder_hidden_states(33) + timestep = 1 + + with torch.no_grad(): + _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + mem_bytes = torch.cuda.max_memory_allocated() + + assert mem_bytes < 5 * 10**9 + + def test_set_attention_slice_max(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + unet = self.get_unet_model() + unet.set_attention_slice("max") + + latents = self.get_latents(33) + encoder_hidden_states = self.get_encoder_hidden_states(33) + timestep = 1 + + with torch.no_grad(): + _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + mem_bytes = torch.cuda.max_memory_allocated() + + assert mem_bytes < 5 * 10**9 + + def 
test_set_attention_slice_int(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + unet = self.get_unet_model() + unet.set_attention_slice(2) + + latents = self.get_latents(33) + encoder_hidden_states = self.get_encoder_hidden_states(33) + timestep = 1 + + with torch.no_grad(): + _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + mem_bytes = torch.cuda.max_memory_allocated() + + assert mem_bytes < 5 * 10**9 + + def test_set_attention_slice_list(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + # there are 32 slicable layers + slice_list = 16 * [2, 3] + unet = self.get_unet_model() + unet.set_attention_slice(slice_list) + + latents = self.get_latents(33) + encoder_hidden_states = self.get_encoder_hidden_states(33) + timestep = 1 + + with torch.no_grad(): + _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + mem_bytes = torch.cuda.max_memory_allocated() + + assert mem_bytes < 5 * 10**9 + + def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): + dtype = torch.float16 if fp16 else torch.float32 + hidden_states = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype) + return hidden_states + + @parameterized.expand( + [ + # fmt: off + [33, 4, [-0.4424, 0.1510, -0.1937, 0.2118, 0.3746, -0.3957, 0.0160, -0.0435]], + [47, 0.55, [-0.1508, 0.0379, -0.3075, 0.2540, 0.3633, -0.0821, 0.1719, -0.0207]], + [21, 0.89, [-0.6479, 0.6364, -0.3464, 0.8697, 0.4443, -0.6289, -0.0091, 0.1778]], + [9, 1000, [0.8888, -0.5659, 0.5834, -0.7469, 1.1912, -0.3923, 1.1241, -0.4424]], + # fmt: on + ] + ) + @require_torch_gpu + def test_compvis_sd_v1_4(self, seed, timestep, expected_slice): + model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4") + latents = self.get_latents(seed) + encoder_hidden_states = self.get_encoder_hidden_states(seed) + + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + + with torch.no_grad(): + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + assert sample.shape == latents.shape + + output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) + + @parameterized.expand( + [ + # fmt: off + [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]], + [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]], + [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]], + [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]], + # fmt: on + ] + ) + @require_torch_gpu + def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice): + model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True) + latents = self.get_latents(seed, fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) + + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + + with torch.no_grad(): + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + assert sample.shape == latents.shape + + output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert 
torch_all_close(output_slice, expected_output_slice, atol=5e-3) + + @parameterized.expand( + [ + # fmt: off + [33, 4, [-0.4430, 0.1570, -0.1867, 0.2376, 0.3205, -0.3681, 0.0525, -0.0722]], + [47, 0.55, [-0.1415, 0.0129, -0.3136, 0.2257, 0.3430, -0.0536, 0.2114, -0.0436]], + [21, 0.89, [-0.7091, 0.6664, -0.3643, 0.9032, 0.4499, -0.6541, 0.0139, 0.1750]], + [9, 1000, [0.8878, -0.5659, 0.5844, -0.7442, 1.1883, -0.3927, 1.1192, -0.4423]], + # fmt: on + ] + ) + @require_torch_gpu + def test_compvis_sd_v1_5(self, seed, timestep, expected_slice): + model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5") + latents = self.get_latents(seed) + encoder_hidden_states = self.get_encoder_hidden_states(seed) + + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + + with torch.no_grad(): + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + assert sample.shape == latents.shape + + output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) + + @parameterized.expand( + [ + # fmt: off + [83, 4, [-0.2695, -0.1669, 0.0073, -0.3181, -0.1187, -0.1676, -0.1395, -0.5972]], + [17, 0.55, [-0.1290, -0.2588, 0.0551, -0.0916, 0.3286, 0.0238, -0.3669, 0.0322]], + [8, 0.89, [-0.5283, 0.1198, 0.0870, -0.1141, 0.9189, -0.0150, 0.5474, 0.4319]], + [3, 1000, [-0.5601, 0.2411, -0.5435, 0.1268, 1.1338, -0.2427, -0.0280, -1.0020]], + # fmt: on + ] + ) + @require_torch_gpu + def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice): + model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5", fp16=True) + latents = self.get_latents(seed, fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) + + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + + with torch.no_grad(): + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + assert sample.shape == latents.shape + + output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=5e-3) + + @parameterized.expand( + [ + # fmt: off + [33, 4, [-0.7639, 0.0106, -0.1615, -0.3487, -0.0423, -0.7972, 0.0085, -0.4858]], + [47, 0.55, [-0.6564, 0.0795, -1.9026, -0.6258, 1.8235, 1.2056, 1.2169, 0.9073]], + [21, 0.89, [0.0327, 0.4399, -0.6358, 0.3417, 0.4120, -0.5621, -0.0397, -1.0430]], + [9, 1000, [0.1600, 0.7303, -1.0556, -0.3515, -0.7440, -1.2037, -1.8149, -1.8931]], + # fmt: on + ] + ) + @require_torch_gpu + def test_compvis_sd_inpaint(self, seed, timestep, expected_slice): + model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting") + latents = self.get_latents(seed, shape=(4, 9, 64, 64)) + encoder_hidden_states = self.get_encoder_hidden_states(seed) + + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + + with torch.no_grad(): + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + assert sample.shape == (4, 4, 64, 64) + + output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) + + @parameterized.expand( + [ + # fmt: off + [83, 4, [-0.1047, -1.7227, 0.1067, 0.0164, -0.5698, -0.4172, -0.1388, 1.1387]], + 
[17, 0.55, [0.0975, -0.2856, -0.3508, -0.4600, 0.3376, 0.2930, -0.2747, -0.7026]], + [8, 0.89, [-0.0952, 0.0183, -0.5825, -0.1981, 0.1131, 0.4668, -0.0395, -0.3486]], + [3, 1000, [0.4790, 0.4949, -1.0732, -0.7158, 0.7959, -0.9478, 0.1105, -0.9741]], + # fmt: on + ] + ) + @require_torch_gpu + def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice): + model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting", fp16=True) + latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) + + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + + with torch.no_grad(): + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + assert sample.shape == (4, 4, 64, 64) + + output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=5e-3) + + @parameterized.expand( + [ + # fmt: off + [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]], + [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]], + [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]], + [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]], + # fmt: on + ] + ) + @require_torch_gpu + def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice): + model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True) + latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True) + + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + + with torch.no_grad(): + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + assert sample.shape == latents.shape + + output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=5e-3) From d8f1a6b354a961bea33ad9d3cc988efd40d5a4a3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 16 Jan 2023 19:35:11 +0100 Subject: [PATCH 12/26] fix more --- examples/lora/train_lora.py | 7 +- .../textual_inversion_bf16.py | 5 +- .../textual_inversion/textual_inversion.py | 5 +- .../textual_inversion_flax.py | 5 +- src/diffusers/dependency_versions_table.py | 2 +- src/diffusers/loaders.py | 185 ++++++ src/diffusers/models/attention.py | 6 +- src/diffusers/models/cross_attention.py | 6 +- src/diffusers/models/modeling_utils.py | 4 +- src/diffusers/models/unet_2d_condition.py | 8 +- src/diffusers/pipelines/loaders.py | 87 --- .../pipeline_stable_diffusion.py | 3 +- .../scheduling_euler_ancestral_discrete.py | 8 +- .../schedulers/scheduling_euler_discrete.py | 8 +- tests/models/_ | 528 ++++++++++++++++++ tests/models/test_models_unet_2d.py | 9 +- tests/models/test_models_unet_2d_condition.py | 117 +++- tests/test_scheduler.py | 6 +- 18 files changed, 852 insertions(+), 147 deletions(-) create mode 100644 src/diffusers/loaders.py delete mode 100644 src/diffusers/pipelines/loaders.py create mode 100644 tests/models/_ diff --git a/examples/lora/train_lora.py b/examples/lora/train_lora.py index 1ba288a17aa3..a5d40831ba34 100644 --- a/examples/lora/train_lora.py +++ b/examples/lora/train_lora.py @@ -5,7 +5,7 @@ 
import os import warnings from pathlib import Path -from typing import Optional, Dict +from typing import Dict, Optional import torch import torch.nn.functional as F @@ -15,12 +15,13 @@ import datasets import diffusers import transformers +import wandb from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel -from diffusers.optimization import get_scheduler from diffusers.models.cross_attention import LoRACrossAttnProcessor +from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version from diffusers.utils.import_utils import is_xformers_available from huggingface_hub import HfFolder, Repository, whoami @@ -28,7 +29,7 @@ from torchvision import transforms from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig -import wandb + wandb.login() diff --git a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py index a9b663b2e68c..4a7540aa161b 100644 --- a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py +++ b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py @@ -336,10 +336,7 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - ( - h, - w, - ) = ( + (h, w,) = ( img.shape[0], img.shape[1], ) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 11e145e63a3f..a9f766ac79c0 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -381,10 +381,7 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - ( - h, - w, - ) = ( + (h, w,) = ( img.shape[0], img.shape[1], ) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index 1001126402ec..2fb961f9a3f9 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -306,10 +306,7 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - ( - h, - w, - ) = ( + (h, w,) = ( img.shape[0], img.shape[1], ) diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 1ef1edc14629..7fc779fc543e 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -4,7 +4,7 @@ deps = { "Pillow": "Pillow", "accelerate": "accelerate>=0.11.0", - "black": "black==22.8", + "black": "black==22.12", "datasets": "datasets", "filelock": "filelock", "flake8": "flake8>=3.8.3", diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py new file mode 100644 index 000000000000..eb3d1f46982f --- /dev/null +++ b/src/diffusers/loaders.py @@ -0,0 +1,185 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import os +from collections import defaultdict +from typing import Callable, Dict, Union + +import torch + +from .models.cross_attention import LoRACrossAttnProcessor +from .models.modeling_utils import _get_model_file +from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, logging + + +logger = logging.get_logger(__name__) + + +ATTN_WEIGHT_NAME = "pytorch_attn_procs.bin" + + +class AttnProcsLayers(torch.nn.Module): + def __init__(self, state_dict: Dict[str, torch.Tensor]): + super().__init__() + self.layers = torch.nn.ModuleList(state_dict.values()) + self.mapping = {k: v for k, v in enumerate(state_dict.keys())} + self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} + + # we add a hook to state_dict() and load_state_dict() so that the + # naming fits with `unet.attn_processors` + def map_to(module, state_dict, *args, **kwargs): + new_state_dict = {} + for key, value in state_dict.items(): + num = int(key.split(".")[1]) # 0 is always "layers" + new_key = key.replace(f"layers.{num}", module.mapping[num]) + new_state_dict[new_key] = value + + return new_state_dict + + def map_from(module, state_dict, *args, **kwargs): + all_keys = list(state_dict.keys()) + for key in all_keys: + replace_key = ".".join(key.split(".processor.")[1:]) + new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") + state_dict[new_key] = state_dict[key] + del state_dict[key] + + self._register_state_dict_hook(map_to) + self._register_load_state_dict_pre_hook(map_from, with_module=True) + + +class AttnProcsLoader: + def load_attn_procs(self, pretrained_model_name_or_path, **kwargs): + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", ATTN_WEIGHT_NAME) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
+ ) + + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + + state_dict = torch.load(model_file, map_location="cpu") + + # fill attn processors + attn_processors = {} + + is_lora = all("lora" in k for k in state_dict.keys()) + + if is_lora: + lora_grouped_dict = defaultdict(dict) + for key, value in state_dict.items(): + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + lora_grouped_dict[attn_processor_key][sub_key] = value + + for key, value_dict in lora_grouped_dict.items(): + rank = value_dict["to_k_lora.lora_down.weight"].shape[0] + cross_attention_dim = value_dict["to_k_lora.lora_down.weight"].shape[1] + hidden_size = value_dict["to_k_lora.lora_up.weight"].shape[0] + + attn_processors[key] = LoRACrossAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank + ) + attn_processors[key].load_state_dict(value_dict) + + else: + raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.") + + # dtype + if torch_dtype is not None: + attn_processors = {k: v.to(torch_dtype) for k, v in attn_processors.items()} + + # device + attn_processors = {k: v.to(self.device) for k, v in attn_processors.items()} + + # set layers + self.set_attn_processor(attn_processors) + + def save_attn_procs( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + weights_name: str = ATTN_WEIGHT_NAME, + save_function: Callable = None, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~models.ModelMixin.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + if save_function is None: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + + model_to_save = AttnProcsLayers(self.attn_processors) + + # Save the model + state_dict = model_to_save.state_dict() + + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + # If we have a shard file that is not going to be replaced, we delete it, but only from the main process + # in distributed settings to avoid race conditions. 
+ weights_no_suffix = weights_name.replace(".bin", "") + if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process: + os.remove(full_filename) + + # Save the model + save_function(state_dict, os.path.join(save_directory, weights_name)) + + logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}") diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 85dcc800fd1e..ffe67987467b 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -90,10 +90,8 @@ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atten if use_memory_efficient_attention_xformers: if not is_xformers_available(): raise ModuleNotFoundError( - ( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers" - ), + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", name="xformers", ) elif not torch.cuda.is_available(): diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 2a95865a2a1c..82363fc1d287 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -105,10 +105,8 @@ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atten ) elif not is_xformers_available(): raise ModuleNotFoundError( - ( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers" - ), + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", name="xformers", ) elif not torch.cuda.is_available(): diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index e02c13668677..1822bb49d561 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -474,7 +474,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P else: if is_safetensors_available(): try: - model_file = cls._get_model_file( + model_file = _get_model_file( pretrained_model_name_or_path, weights_name=SAFETENSORS_WEIGHTS_NAME, cache_dir=cache_dir, @@ -490,7 +490,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P except: pass if model_file is None: - model_file = cls._get_model_file( + model_file = _get_model_file( pretrained_model_name_or_path, weights_name=WEIGHTS_NAME, cache_dir=cache_dir, diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 0d9b71f377f7..aa1225f527d6 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -19,6 +19,7 @@ import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import AttnProcsLoader from ..utils import BaseOutput, logging from .cross_attention import AttnProcessor from .embeddings import TimestepEmbedding, Timesteps @@ -49,7 +50,7 @@ class UNet2DConditionOutput(BaseOutput): sample: torch.FloatTensor -class UNet2DConditionModel(ModelMixin, ConfigMixin): +class UNet2DConditionModel(ModelMixin, ConfigMixin, AttnProcsLoader): r""" UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep and returns sample shaped output. 
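# Usage sketch for the load_attn_procs() method that the AttnProcsLoader mixin above
# adds to UNet2DConditionModel. "path/to/lora-output" stands for a hypothetical
# directory written by save_attn_procs() (e.g. by examples/lora/train_lora.py) and
# containing pytorch_attn_procs.bin; the model id and prompt are placeholders, not
# values taken from this patch.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# load_attn_procs() rebuilds one LoRACrossAttnProcessor per attention layer from the
# saved state dict and attaches them via set_attn_processor().
pipe.unet.load_attn_procs("path/to/lora-output")

image = pipe("a photo of sks dog", num_inference_steps=25).images[0]
image.save("lora_sample.png")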
@@ -123,7 +124,6 @@ def __init__( num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", - attn_processor_cls=None, ): super().__init__() @@ -267,10 +267,6 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) - if attn_processor_cls is not None: - attn_processors_keys = self.attn_processors.keys() - self.set_attn_processor({k: attn_processor_cls() for k in attn_processors_keys}) - @property def attn_processors(self) -> Dict[str, AttnProcessor]: # set recursively diff --git a/src/diffusers/pipelines/loaders.py b/src/diffusers/pipelines/loaders.py deleted file mode 100644 index d39914fb00e3..000000000000 --- a/src/diffusers/pipelines/loaders.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from ..models.cross_attention import LoRACrossAttnProcessor -from ..models.modeling_utils import _get_model_file -from collections import defaultdict -import torch -from ..utils import ( - DIFFUSERS_CACHE, - HF_HUB_OFFLINE, - logging, -) - - -logger = logging.get_logger(__name__) - - -LORA_WEIGHT_NAME = "pytorch_lora.bin" - - -class LoraUNetLoaderMixin: - def load_lora(self, pretrained_model_name_or_path, **kwargs): - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - torch_dtype = kwargs.pop("torch_dtype", None) - subfolder = kwargs.pop("subfolder", None) - - user_agent = { - "file_type": "lora", - "framework": "pytorch", - } - - if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): - raise ValueError( - f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
- ) - - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=LORA_WEIGHT_NAME, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - - state_dict = torch.load(model_file, map_location="cpu") - lora_grouped_dict = defaultdict(dict) - for key, value in state_dict.items(): - attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) - lora_grouped_dict[attn_processor_key][sub_key] = value - - attn_processors = {} - for key, value_dict in lora_grouped_dict.items(): - rank = value_dict["to_k_lora.lora_down.weight"].shape[0] - cross_attention_dim = value_dict["to_k_lora.lora_down.weight"].shape[1] - hidden_size = value_dict["to_k_lora.lora_up.weight"].shape[0] - - attn_processors[key] = LoRACrossAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank - ) - attn_processors[key].load_state_dict(value_dict) - - if torch_dtype is not None: - attn_processors[key].to(torch_dtype) - - self.unet.set_attn_processor(attn_processors) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 1d97226c0d09..c3b4b905e0d2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -32,7 +32,6 @@ ) from ...utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline -from ..loaders import LoraUNetLoaderMixin from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -54,7 +53,7 @@ """ -class StableDiffusionPipeline(DiffusionPipeline, LoraUNetLoaderMixin): +class StableDiffusionPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion. diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 9976235b75f6..2db7bb67bcbd 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -189,11 +189,9 @@ def step( or isinstance(timestep, torch.LongTensor) ): raise ValueError( - ( - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" - " one of the `scheduler.timesteps` as a timestep." - ), + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep.", ) if not self.is_scale_input_called: diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 10f277f7e090..f1e9100acfe2 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -198,11 +198,9 @@ def step( or isinstance(timestep, torch.LongTensor) ): raise ValueError( - ( - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.step()` is not supported. 
Make sure to pass" - " one of the `scheduler.timesteps` as a timestep." - ), + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep.", ) if not self.is_scale_input_called: diff --git a/tests/models/_ b/tests/models/_ new file mode 100644 index 000000000000..249ef454ba82 --- /dev/null +++ b/tests/models/_ @@ -0,0 +1,528 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging +from .cross_attention import AttnProcessor +from ..loaders import AttnProcsLoader +from .embeddings import TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin +from .unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + UpBlock2D, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet2DConditionModel(ModelMixin, ConfigMixin, AttnProcsLoader): + r""" + UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`. 
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately + summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: str = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attn_processor_cls=None, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + 
self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + 
attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + if attn_processor_cls is not None: + attn_processors_keys = self.attn_processors.keys() + self.set_attn_processor({k: attn_processor_cls() for k in attn_processors_keys}) + + @property + def attn_processors(self) -> Dict[str, AttnProcessor]: + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]): + if hasattr(module, "set_processor"): + processors[name] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]): + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(name)) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. 
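# Usage sketch for set_attention_slice(), assuming the CompVis/stable-diffusion-v1-4
# UNet used by the integration tests earlier in this patch series; that model exposes
# 32 sliceable layers, which is why the per-layer list below has 32 entries.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

unet.set_attention_slice("auto")       # dim // 2 per layer: attention in two steps
unet.set_attention_slice("max")        # one slice at a time, lowest peak memory
unet.set_attention_slice(2)            # same slice size for every sliceable layer
unet.set_attention_slice(16 * [2, 3])  # explicit per-layer slice sizes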
+ """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. 
+ default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + + # 5. 
up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index 803a9e88d0cb..39cd98a14726 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -21,13 +21,8 @@ import torch from diffusers import UNet2DModel -from diffusers.utils import ( - floats_tensor, - logging, - slow, - torch_all_close, - torch_device, -) +from diffusers.utils import floats_tensor, logging, slow, torch_all_close, torch_device + from ..test_modeling_common import ModelTesterMixin diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 3cf3095c1477..43220cdf73cb 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -14,11 +14,13 @@ # limitations under the License. 
import gc +import tempfile import unittest import torch from diffusers import UNet2DConditionModel +from diffusers.models.cross_attention import LoRACrossAttnProcessor from diffusers.utils import ( floats_tensor, load_hf_numpy, @@ -90,11 +92,6 @@ def test_xformers_enable_works(self): model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers ), "xformers is not enabled" - def test_set_attention_processors(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS") def test_gradient_checkpointing(self): # enable deterministic behavior for gradient checkpointing @@ -273,6 +270,116 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma assert processor.is_run assert processor.number == 123 + def test_lora_processors(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + sample1 = model(**inputs_dict).sample + + lora_attn_procs = {} + for name in model.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1") else model.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = model.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = model.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRACrossAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + ) + + # add 1 to weights to mock trained weights + with torch.no_grad(): + lora_attn_procs[name].to_q_lora.lora_up.weight += 1 + lora_attn_procs[name].to_k_lora.lora_up.weight += 1 + lora_attn_procs[name].to_v_lora.lora_up.weight += 1 + lora_attn_procs[name].to_out_lora.lora_up.weight += 1 + + # make sure we can set a list of attention processors + model.set_attn_processor(lora_attn_procs) + model.to(torch_device) + + # test that attn processors can be set to itself + model.set_attn_processor(model.attn_processors) + + with torch.no_grad(): + sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample + sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + assert (sample1 - sample2).abs().max() < 1e-4 + assert (sample3 - sample4).abs().max() < 1e-4 + + # sample 2 and sample 3 should be different + assert (sample2 - sample3).abs().max() > 1e-4 + + def test_lora_save_load(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + old_sample = model(**inputs_dict).sample + + lora_attn_procs = {} + for name in model.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1") else model.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = model.config.block_out_channels[-1] + elif 
name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = model.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRACrossAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + ) + lora_attn_procs[name] = lora_attn_procs[name].to(model.device) + + # add 1 to weights to mock trained weights + with torch.no_grad(): + lora_attn_procs[name].to_q_lora.lora_up.weight += 1 + lora_attn_procs[name].to_k_lora.lora_up.weight += 1 + lora_attn_procs[name].to_v_lora.lora_up.weight += 1 + lora_attn_procs[name].to_out_lora.lora_up.weight += 1 + + model.set_attn_processor(lora_attn_procs) + + with torch.no_grad(): + sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + new_model.load_attn_procs(tmpdirname) + + with torch.no_grad(): + new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + assert (sample - new_sample).abs().max() < 1e-4 + + # LoRA and no LoRA should NOT be the same + assert (sample - old_sample).abs().max() > 1e-4 + @slow class UNet2DConditionModelIntegrationTests(unittest.TestCase): diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 69831dee1bed..34770222d529 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -537,10 +537,8 @@ def test_scheduler_public_api(self): ) self.assertTrue( hasattr(scheduler, "scale_model_input"), - ( - f"{scheduler_class} does not implement a required class method `scale_model_input(sample," - " timestep)`" - ), + f"{scheduler_class} does not implement a required class method `scale_model_input(sample," + " timestep)`", ) self.assertTrue( hasattr(scheduler, "step"), From c5cf0a062648bb7f4fedaa70e4aa0dd4c8b7b365 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 16 Jan 2023 19:35:34 +0100 Subject: [PATCH 13/26] up --- tests/models/_ | 528 ------------------------------------------------- 1 file changed, 528 deletions(-) delete mode 100644 tests/models/_ diff --git a/tests/models/_ b/tests/models/_ deleted file mode 100644 index 249ef454ba82..000000000000 --- a/tests/models/_ +++ /dev/null @@ -1,528 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.utils.checkpoint - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, logging -from .cross_attention import AttnProcessor -from ..loaders import AttnProcsLoader -from .embeddings import TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .unet_2d_blocks import ( - CrossAttnDownBlock2D, - CrossAttnUpBlock2D, - DownBlock2D, - UNetMidBlock2DCrossAttn, - UNetMidBlock2DSimpleCrossAttn, - UpBlock2D, - get_down_block, - get_up_block, -) - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class UNet2DConditionOutput(BaseOutput): - """ - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. - """ - - sample: torch.FloatTensor - - -class UNet2DConditionModel(ModelMixin, ConfigMixin, AttnProcsLoader): - r""" - UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep - and returns sample shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the models (such as downloading or saving, etc.) - - Parameters: - sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. - center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - flip_sin_to_cos (`bool`, *optional*, defaults to `False`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. - mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): - The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): - The tuple of upsample blocks to use. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. - mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. - norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. - cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. - attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. 
- resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately - summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 4, - out_channels: int = 4, - center_input_sample: bool = False, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - mid_block_type: str = "UNetMidBlock2DCrossAttn", - up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: int = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: int = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, - attention_head_dim: Union[int, Tuple[int]] = 8, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - class_embed_type: Optional[str] = None, - num_class_embeds: Optional[int] = None, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - attn_processor_cls=None, - ): - super().__init__() - - self.sample_size = sample_size - time_embed_dim = block_out_channels[0] * 4 - - # input - self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) - - # time - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - - # class embedding - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - else: - self.class_embedding = None - - self.down_blocks = nn.ModuleList([]) - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - if isinstance(only_cross_attention, bool): - only_cross_attention = [only_cross_attention] * len(down_block_types) - - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[i], - downsample_padding=downsample_padding, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - 
only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.down_blocks.append(down_block) - - # mid - if mid_block_type == "UNetMidBlock2DCrossAttn": - self.mid_block = UNetMidBlock2DCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": - self.mid_block = UNetMidBlock2DSimpleCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - else: - raise ValueError(f"unknown mid_block_type : {mid_block_type}") - - # count how many layers upsample the images - self.num_upsamplers = 0 - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_attention_head_dim = list(reversed(attention_head_dim)) - only_cross_attention = list(reversed(only_cross_attention)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - is_final_block = i == len(block_out_channels) - 1 - - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - - # add upsample block for all BUT final layer - if not is_final_block: - add_upsample = True - self.num_upsamplers += 1 - else: - add_upsample = False - - up_block = get_up_block( - up_block_type, - num_layers=layers_per_block + 1, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, - add_upsample=add_upsample, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=reversed_attention_head_dim[i], - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) - - if attn_processor_cls is not None: - attn_processors_keys = self.attn_processors.keys() - self.set_attn_processor({k: attn_processor_cls() for k in attn_processors_keys}) - - @property - def attn_processors(self) -> Dict[str, AttnProcessor]: - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]): - if hasattr(module, "set_processor"): - processors[name] = module.processor - - for sub_name, child in module.named_children(): - 
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]): - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(name)) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def set_attention_slice(self, slice_size): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - - for child in module.children(): - fn_recursive_retrieve_slicable_dims(child) - - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_slicable_dims(module) - - num_slicable_layers = len(sliceable_head_dims) - - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_slicable_layers * [1] - - slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size - - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." - ) - - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - # Recursively walk through all the children. 
- # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): - module.gradient_checkpointing = value - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - class_labels: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> Union[UNet2DConditionOutput, Tuple]: - r""" - Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. - - Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). - # However, the upsampling interpolation output size can be forced to fit any upsampling size - # on the fly if necessary. - default_overall_up_factor = 2**self.num_upsamplers - - # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - forward_upsample_size = False - upsample_size = None - - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - logger.info("Forward upsample size to force interpolation output size.") - forward_upsample_size = True - - # prepare attention_mask - if attention_mask is not None: - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) - emb = self.time_embedding(t_emb) - - if self.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb - - # 2. pre-process - sample = self.conv_in(sample) - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # 4. mid - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - - # 5. up - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 - - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - # if we have not reached the final block and need to forward the - # upsample size, we do it here - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - - if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - upsample_size=upsample_size, - attention_mask=attention_mask, - ) - else: - sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size - ) - # 6. 
post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if not return_dict: - return (sample,) - - return UNet2DConditionOutput(sample=sample) From 060697e7c2428800dcb3050667b995db8b9a88c3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 16 Jan 2023 19:41:29 +0100 Subject: [PATCH 14/26] finish more --- examples/lora/train_lora.py | 39 +++---------------- .../textual_inversion_bf16.py | 5 ++- .../textual_inversion/textual_inversion.py | 5 ++- .../textual_inversion_flax.py | 5 ++- src/diffusers/models/attention.py | 6 ++- src/diffusers/models/cross_attention.py | 6 ++- .../scheduling_euler_ancestral_discrete.py | 8 ++-- .../schedulers/scheduling_euler_discrete.py | 8 ++-- tests/test_scheduler.py | 6 ++- 9 files changed, 39 insertions(+), 49 deletions(-) diff --git a/examples/lora/train_lora.py b/examples/lora/train_lora.py index a5d40831ba34..cf3ef9150acf 100644 --- a/examples/lora/train_lora.py +++ b/examples/lora/train_lora.py @@ -5,7 +5,7 @@ import os import warnings from pathlib import Path -from typing import Dict, Optional +from typing import Optional import torch import torch.nn.functional as F @@ -20,6 +20,7 @@ from accelerate.logging import get_logger from accelerate.utils import set_seed from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers.loaders import AttnProcsLayers from diffusers.models.cross_attention import LoRACrossAttnProcessor from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version @@ -34,38 +35,8 @@ wandb.login() -class LoraLayers(torch.nn.Module): - def __init__(self, state_dict: Dict[str, torch.Tensor]): - super().__init__() - self.layers = torch.nn.ModuleList(state_dict.values()) - self.mapping = {k: v for k, v in enumerate(state_dict.keys())} - self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} - - # we add a hook to state_dict() and load_state_dict() so that the - # naming fits with `unet.attn_processors` - def map_to(module, state_dict, *args, **kwargs): - new_state_dict = {} - for key, value in state_dict.items(): - num = int(key.split(".")[1]) # 0 is always "layers" - new_key = key.replace(f"layers.{num}", module.mapping[num]) - new_state_dict[new_key] = value - - return new_state_dict - - def map_from(module, state_dict, *args, **kwargs): - all_keys = list(state_dict.keys()) - for key in all_keys: - replace_key = ".".join(key.split(".")[:-3]) - new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") - state_dict[new_key] = state_dict[key] - del state_dict[key] - - self._register_state_dict_hook(map_to) - self._register_load_state_dict_pre_hook(map_from, with_module=True) - - # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.10.0.dev0") +check_min_version("0.12.0.dev0") logger = get_logger(__name__) @@ -638,7 +609,7 @@ def main(args): ) unet.set_attn_processor(lora_attn_procs) - lora_layers = LoraLayers(unet.attn_processors) + lora_layers = AttnProcsLayers(unet.attn_processors) state_dict = lora_layers.state_dict() lora_layers.load_state_dict(state_dict) @@ -898,7 +869,7 @@ def main(args): # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: - torch.save(lora_layers.state_dict(), os.path.join(args.output_dir, "lora_layers.bin")) + unet.save_attn_procs(args.output_dir) if args.push_to_hub: repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) diff --git a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py index 4a7540aa161b..a9b663b2e68c 100644 --- a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py +++ b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py @@ -336,7 +336,10 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - (h, w,) = ( + ( + h, + w, + ) = ( img.shape[0], img.shape[1], ) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index a9f766ac79c0..11e145e63a3f 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -381,7 +381,10 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - (h, w,) = ( + ( + h, + w, + ) = ( img.shape[0], img.shape[1], ) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index 2fb961f9a3f9..1001126402ec 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -306,7 +306,10 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - (h, w,) = ( + ( + h, + w, + ) = ( img.shape[0], img.shape[1], ) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index ffe67987467b..85dcc800fd1e 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -90,8 +90,10 @@ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atten if use_memory_efficient_attention_xformers: if not is_xformers_available(): raise ModuleNotFoundError( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers", + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), name="xformers", ) elif not torch.cuda.is_available(): diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 82363fc1d287..2a95865a2a1c 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -105,8 +105,10 @@ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atten ) elif not is_xformers_available(): raise ModuleNotFoundError( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers", + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), name="xformers", ) elif not torch.cuda.is_available(): diff --git 
a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 2db7bb67bcbd..9976235b75f6 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -189,9 +189,11 @@ def step( or isinstance(timestep, torch.LongTensor) ): raise ValueError( - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" - " one of the `scheduler.timesteps` as a timestep.", + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), ) if not self.is_scale_input_called: diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index f1e9100acfe2..10f277f7e090 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -198,9 +198,11 @@ def step( or isinstance(timestep, torch.LongTensor) ): raise ValueError( - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" - " one of the `scheduler.timesteps` as a timestep.", + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), ) if not self.is_scale_input_called: diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 34770222d529..69831dee1bed 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -537,8 +537,10 @@ def test_scheduler_public_api(self): ) self.assertTrue( hasattr(scheduler, "scale_model_input"), - f"{scheduler_class} does not implement a required class method `scale_model_input(sample," - " timestep)`", + ( + f"{scheduler_class} does not implement a required class method `scale_model_input(sample," + " timestep)`" + ), ) self.assertTrue( hasattr(scheduler, "step"), From 5d5fd77edbbce6d100e0fd3a594dc35f1c083eb9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 17 Jan 2023 11:53:17 +0000 Subject: [PATCH 15/26] finish more --- examples/dreambooth/README.md | 92 +++++ .../train_dreambooth_lora.py} | 106 ++++-- examples/lora/README.md | 323 ------------------ src/diffusers/loaders.py | 48 ++- src/diffusers/models/unet_2d_condition.py | 4 +- 5 files changed, 193 insertions(+), 380 deletions(-) rename examples/{lora/train_lora.py => dreambooth/train_dreambooth_lora.py} (91%) delete mode 100644 examples/lora/README.md diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index 28ac0b4f6a96..621dac8a681c 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -5,6 +5,7 @@ The `train_dreambooth.py` script shows how to implement the training procedure a ## Running locally with PyTorch + ### Installing the dependencies Before running the scripts, make sure to install the library's training dependencies: @@ -235,6 +236,97 @@ image.save("dog-bucket.png") You can also perform inference from one of the checkpoints saved during the training process, if you used the `--checkpointing_steps` argument. 
Please, refer to [the documentation](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint) to see how to do it.
+## Training with Low-Rank Adaptation of Large Language Models (LoRA)
+
+Low-Rank Adaptation of Large Language Models was first introduced in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen* at Microsoft.
+
+In a nutshell, LoRA adapts pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights (a short code sketch of this idea is shown further below). This has a couple of advantages:
+- The previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
+- LoRA attention layers allow controlling the extent to which the model is adapted toward new training images via a `scale` parameter.
+
+### Training
+
+Let's get started with a simple example. We will re-use the dog example of the [previous section](#dog-toy-example).
+
+First, you need to set up your dreambooth training example as explained in the [installation section](#Installing-the-dependencies).
+Next, let's download the toy dog dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. Make sure to set `INSTANCE_DIR` to the name of your directory further below. This will be our training data.
+
+Now, you can launch the training. Here we will use [Stable Diffusion 1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5).
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [wandb](https://docs.wandb.ai/quickstart) is a nice solution to easily see the generated images during training. All you need to do is run `pip install wandb` before training to automatically log images.___**
+
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export INSTANCE_DIR="path-to-instance-images"
+export OUTPUT_DIR="path-to-save-model"
+```
+
+For this example we want to directly store the trained LoRA embeddings on the Hub, so
+we need to be logged in and add the `--push_to_hub` flag.
+
+```bash
+huggingface-cli login
+```
+
+Now we can start training!
+
+```bash
+accelerate launch train_dreambooth_lora.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --instance_data_dir=$INSTANCE_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --instance_prompt="a photo of sks dog" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=1 \
+  --checkpointing_steps=100 \
+  --learning_rate=1e-4 \
+  --report_to="wandb" \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --max_train_steps=500 \
+  --validation_prompt="A photo of sks dog in a bucket" \
+  --seed="0" \
+  --push_to_hub
+```
+
+**___Note: When using LoRA we can use a much higher learning rate compared to vanilla dreambooth. Here we
+use *1e-4* instead of the usual *2e-6*.___**
+
+The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora](https://huggingface.co/patrickvonplaten/lora).
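+To get an intuition for how this works and why the trained weights end up so small, here is a minimal, self-contained sketch of such a rank-decomposition layer. This is purely illustrative: the class name `LoRALinearSketch`, the default `rank`, and the initialization below are assumptions made for this example and are not part of the `diffusers` API.
+
+```python
+import torch
+import torch.nn as nn
+
+
+class LoRALinearSketch(nn.Module):
+    """A frozen linear layer plus a trainable low-rank update (illustrative sketch only)."""
+
+    def __init__(self, in_features: int, out_features: int, rank: int = 4):
+        super().__init__()
+        self.base = nn.Linear(in_features, out_features, bias=False)
+        self.base.weight.requires_grad_(False)  # pretrained weights stay frozen
+        # Rank-decomposition pair: only these two small matrices are trained.
+        self.lora_down = nn.Linear(in_features, rank, bias=False)
+        self.lora_up = nn.Linear(rank, out_features, bias=False)
+        nn.init.normal_(self.lora_down.weight, std=1.0 / rank)
+        nn.init.zeros_(self.lora_up.weight)  # the low-rank update starts out as a no-op
+
+    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+        # `scale` controls how strongly the adapted weights influence the output.
+        return self.base(x) + scale * self.lora_up(self.lora_down(x))
+
+
+layer = LoRALinearSketch(320, 320, rank=4)
+trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
+print(trainable)  # 2 * 4 * 320 = 2560 trainable parameters vs. 320 * 320 = 102400 frozen ones
+```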
**___Note: [The final weights](https://huggingface.co/patrickvonplaten/lora/blob/main/pytorch_attn_procs.bin) are only 3 MB in size which is orders of magnitudes smaller than the original model.** + +and the training results are summarized [here](https://wandb.ai/patrickvonplaten/dreambooth/reports/LoRA-DreamBooth-Dog-Example--VmlldzozMzUzMTcx?accessToken=9drrltpimid0jk8q50p91vwovde24cnimc30g3bjd3i5wys5twi7uczd7jdh85dh) + +### Inference + +After training, LoRA weights can very easily loaded into the original pipeline. First, you need to +load the original pipeline: + +```python +from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler +import torch + +pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) +pipe.to("cuda") +``` + +Next, we can load the adapter layers into the UNet with the [`load_attn_procs` function](TODO:). + +```python +pipe.load_attn_procs("patrickvonplaten/lora") +``` + +Finally, we can run the model in inference. + +```python +image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0] +``` + ## Training with Flax/JAX For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script. diff --git a/examples/lora/train_lora.py b/examples/dreambooth/train_dreambooth_lora.py similarity index 91% rename from examples/lora/train_lora.py rename to examples/dreambooth/train_dreambooth_lora.py index cf3ef9150acf..ce26df338a25 100644 --- a/examples/lora/train_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -19,29 +19,30 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed -from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + UNet2DConditionModel, +) from diffusers.loaders import AttnProcsLayers from diffusers.models.cross_attention import LoRACrossAttnProcessor from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version from diffusers.utils.import_utils import is_xformers_available -from huggingface_hub import HfFolder, Repository, whoami +from huggingface_hub import HfFolder, Repository, create_repo, whoami from PIL import Image from torchvision import transforms from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig -wandb.login() - - # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.12.0.dev0") logger = get_logger(__name__) -run = wandb.init(project="stable_diffusion_lora") - generated_table = wandb.Table(columns=["gen_num", "prompt", "generated_images"]) @@ -115,7 +116,25 @@ def parse_args(input_args=None): help="The prompt to specify images in the same class as provided instance images.", ) parser.add_argument( - "--save_sample_prompt", type=str, default=None, help="A prompt that is sampled during training." 
+ "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=50, + help=( + "Run dreambooth validation every X updates. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), ) parser.add_argument( "--with_prior_preservation", @@ -534,6 +553,8 @@ def main(args): repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) else: repo_name = args.hub_model_id + + repo_name = create_repo(repo_name, exists_ok=True) repo = Repository(args.output_dir, clone_from=repo_name) with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: @@ -594,7 +615,7 @@ def main(args): # Set correct lora layers lora_attn_procs = {} for name in unet.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1") else unet.config.cross_attention_dim + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = unet.config.block_out_channels[-1] elif name.startswith("up_blocks"): @@ -835,45 +856,76 @@ def main(args): if global_step >= args.max_train_steps: break - if args.save_sample_prompt is not None and epoch % 10 == 0: - print("Running inference...") + if args.validation_prompt is not None and epoch % 10 == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." 
+ ) + # create pipeline pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), text_encoder=accelerator.unwrap_model(text_encoder), revision=args.revision, ) - pipeline.save_pretrained(args.output_dir) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) pipeline = pipeline.to(accelerator.device) - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) pipeline.set_progress_bar_config(disable=True) - sample_dir = "/home/patrick_huggingface_co/lora-tryout/samples" - os.makedirs(sample_dir, exist_ok=True) - - for i in tqdm(range(5), desc="Generating samples"): - prompt = args.save_sample_prompt - images = pipeline(prompt, num_inference_steps=30, generator=generator).images - image = images[0] - image.save(os.path.join(sample_dir, f"{i}.png")) - - global_step = epoch * len(train_dataloader) + i - generated_table.add_data(global_step, prompt, wandb.Image(image)) - run.log({"generated_image": wandb.Image(image)}) + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + prompt = args.num_validation_images * [args.validation_prompt] + images = pipeline(prompt, num_inference_steps=25, generator=generator).images + + for tracker in accelerator.trackers: + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) del pipeline torch.cuda.empty_cache() - run.log({"generated_table": generated_table}) - # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: + unet = unet.to(torch.float32) unet.save_attn_procs(args.output_dir) if args.push_to_hub: repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + # Final inference + # Load previous pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=torch.float16 + ) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + + # load attention processors + pipeline.unet.load_attn_procs(args.output_dir) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + prompt = args.num_validation_images * [args.validation_prompt] + images = pipeline(prompt, num_inference_steps=25, generator=generator).images + + for tracker in accelerator.trackers: + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + accelerator.end_training() diff --git a/examples/lora/README.md b/examples/lora/README.md deleted file mode 100644 index 2858c04c48b0..000000000000 --- a/examples/lora/README.md +++ /dev/null @@ -1,323 +0,0 @@ -# DreamBooth training example - -[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject. -The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for stable diffusion. 
- - -## Running locally with PyTorch -### Installing the dependencies - -Before running the scripts, make sure to install the library's training dependencies: - -**Important** - -To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: -```bash -git clone https://github.com/huggingface/diffusers -cd diffusers -pip install -e . -``` - -Then cd in the example folder and run -```bash -pip install -r requirements.txt -``` - -And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: - -```bash -accelerate config -``` - -Or for a default accelerate configuration without answering questions about your environment - -```bash -accelerate config default -``` - -Or if your environment doesn't support an interactive shell e.g. a notebook - -```python -from accelerate.utils import write_basic_config -write_basic_config() -``` - -### Dog toy example - -Now let's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data. - -And launch the training using - -**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** - -```bash -export MODEL_NAME="CompVis/stable-diffusion-v1-4" -export INSTANCE_DIR="path-to-instance-images" -export OUTPUT_DIR="path-to-save-model" - -accelerate launch train_dreambooth.py \ - --pretrained_model_name_or_path=$MODEL_NAME \ - --instance_data_dir=$INSTANCE_DIR \ - --output_dir=$OUTPUT_DIR \ - --instance_prompt="a photo of sks dog" \ - --resolution=512 \ - --train_batch_size=1 \ - --gradient_accumulation_steps=1 \ - --learning_rate=5e-6 \ - --lr_scheduler="constant" \ - --lr_warmup_steps=0 \ - --max_train_steps=400 -``` - -### Training with prior-preservation loss - -Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data. -According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time. 
- -```bash -export MODEL_NAME="CompVis/stable-diffusion-v1-4" -export INSTANCE_DIR="path-to-instance-images" -export CLASS_DIR="path-to-class-images" -export OUTPUT_DIR="path-to-save-model" - -accelerate launch train_dreambooth.py \ - --pretrained_model_name_or_path=$MODEL_NAME \ - --instance_data_dir=$INSTANCE_DIR \ - --class_data_dir=$CLASS_DIR \ - --output_dir=$OUTPUT_DIR \ - --with_prior_preservation --prior_loss_weight=1.0 \ - --instance_prompt="a photo of sks dog" \ - --class_prompt="a photo of dog" \ - --resolution=512 \ - --train_batch_size=1 \ - --gradient_accumulation_steps=1 \ - --learning_rate=5e-6 \ - --lr_scheduler="constant" \ - --lr_warmup_steps=0 \ - --num_class_images=200 \ - --max_train_steps=800 -``` - - -### Training on a 16GB GPU: - -With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes it's possible to run train dreambooth on a 16GB GPU. - -To install `bitandbytes` please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation). - -```bash -export MODEL_NAME="CompVis/stable-diffusion-v1-4" -export INSTANCE_DIR="path-to-instance-images" -export CLASS_DIR="path-to-class-images" -export OUTPUT_DIR="path-to-save-model" - -accelerate launch train_dreambooth.py \ - --pretrained_model_name_or_path=$MODEL_NAME \ - --instance_data_dir=$INSTANCE_DIR \ - --class_data_dir=$CLASS_DIR \ - --output_dir=$OUTPUT_DIR \ - --with_prior_preservation --prior_loss_weight=1.0 \ - --instance_prompt="a photo of sks dog" \ - --class_prompt="a photo of dog" \ - --resolution=512 \ - --train_batch_size=1 \ - --gradient_accumulation_steps=2 --gradient_checkpointing \ - --use_8bit_adam \ - --learning_rate=5e-6 \ - --lr_scheduler="constant" \ - --lr_warmup_steps=0 \ - --num_class_images=200 \ - --max_train_steps=800 -``` - -### Training on a 8 GB GPU: - -By using [DeepSpeed](https://www.deepspeed.ai/) it's possible to offload some -tensors from VRAM to either CPU or NVME allowing to train with less VRAM. - -DeepSpeed needs to be enabled with `accelerate config`. During configuration -answer yes to "Do you want to use DeepSpeed?". With DeepSpeed stage 2, fp16 -mixed precision and offloading both parameters and optimizer state to cpu it's -possible to train on under 8 GB VRAM with a drawback of requiring significantly -more RAM (about 25 GB). See [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options. - -Changing the default Adam optimizer to DeepSpeed's special version of Adam -`deepspeed.ops.adam.DeepSpeedCPUAdam` gives a substantial speedup but enabling -it requires CUDA toolchain with the same version as pytorch. 8-bit optimizer -does not seem to be compatible with DeepSpeed at the moment. 
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

accelerate launch --mixed_precision="fp16" train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="a photo of sks dog" \
  --class_prompt="a photo of dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --sample_batch_size=1 \
  --gradient_accumulation_steps=1 --gradient_checkpointing \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=200 \
  --max_train_steps=800
```

### Fine-tune text encoder with the UNet.

The script also allows you to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning the `text_encoder` gives much better results, especially on faces.
Pass the `--train_text_encoder` argument to the script to enable training the `text_encoder`.

___Note: Training the text encoder requires more memory; with this option the training won't fit on a 16GB GPU. It needs at least 24GB of VRAM.___

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_text_encoder \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="a photo of sks dog" \
  --class_prompt="a photo of dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --use_8bit_adam \
  --gradient_checkpointing \
  --learning_rate=2e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=200 \
  --max_train_steps=800
```

### Using DreamBooth for pipelines other than Stable Diffusion

AltDiffusion also supports DreamBooth now; the command is basically the same as above. All you need to do is change the `pretrained_model_name_or_path` to another architecture such as [`AltDiffusion`](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion), for example:

```
export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion-m9"
or
export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion"
```

### Inference

Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the identifier (e.g. `sks` in the above example) in your prompt.

```python
from diffusers import StableDiffusionPipeline
import torch

model_id = "path-to-your-trained-model"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

prompt = "A photo of sks dog in a bucket"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

image.save("dog-bucket.png")
```

### Inference from a training checkpoint

You can also perform inference from one of the checkpoints saved during the training process, if you used the `--checkpointing_steps` argument.
Please refer to [the documentation](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint) to see how to do it.

## Training with Flax/JAX

For faster training on TPUs and GPUs you can leverage the Flax training example. Follow the instructions above to get the model and dataset before running the script.

___Note: The Flax example doesn't yet support features like gradient checkpointing and gradient accumulation, so you will need a device with more than 30GB of memory to use Flax for faster training.___

Before running the scripts, make sure to install the library's training dependencies:

```bash
pip install -U -r requirements_flax.txt
```

### Training without prior preservation loss

```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"

python train_dreambooth_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --learning_rate=5e-6 \
  --max_train_steps=400
```

### Training with prior preservation loss

```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

python train_dreambooth_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="a photo of sks dog" \
  --class_prompt="a photo of dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --learning_rate=5e-6 \
  --num_class_images=200 \
  --max_train_steps=800
```

### Fine-tune text encoder with the UNet.

```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

python train_dreambooth_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_text_encoder \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="a photo of sks dog" \
  --class_prompt="a photo of dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --learning_rate=2e-6 \
  --num_class_images=200 \
  --max_train_steps=800
```

### Training with xformers:
You can enable memory-efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.

You can also use DreamBooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint).
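Memory-efficient attention from xFormers also helps at inference time on the trained pipeline. A short sketch, assuming xFormers is installed and `"path-to-your-trained-model"` is your training output directory:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "path-to-your-trained-model", torch_dtype=torch.float16
).to("cuda")
# reduces attention memory use and usually speeds up sampling
pipe.enable_xformers_memory_efficient_attention()

image = pipe("A photo of sks dog in a bucket", num_inference_steps=50).images[0]
image.save("dog-bucket.png")
```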
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index eb3d1f46982f..3690e525ba73 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -49,7 +49,7 @@ def map_to(module, state_dict, *args, **kwargs): def map_from(module, state_dict, *args, **kwargs): all_keys = list(state_dict.keys()) for key in all_keys: - replace_key = ".".join(key.split(".processor.")[1:]) + replace_key = key.split(".processor")[0] + ".processor" new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") state_dict[new_key] = state_dict[key] del state_dict[key] @@ -59,7 +59,7 @@ def map_from(module, state_dict, *args, **kwargs): class AttnProcsLoader: - def load_attn_procs(self, pretrained_model_name_or_path, **kwargs): + def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) @@ -67,7 +67,6 @@ def load_attn_procs(self, pretrained_model_name_or_path, **kwargs): local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) - torch_dtype = kwargs.pop("torch_dtype", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", ATTN_WEIGHT_NAME) @@ -76,26 +75,23 @@ def load_attn_procs(self, pretrained_model_name_or_path, **kwargs): "framework": "pytorch", } - if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): - raise ValueError( - f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, ) - - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=weight_name, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - - state_dict = torch.load(model_file, map_location="cpu") + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict # fill attn processors attn_processors = {} @@ -121,12 +117,8 @@ def load_attn_procs(self, pretrained_model_name_or_path, **kwargs): else: raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.") - # dtype - if torch_dtype is not None: - attn_processors = {k: v.to(torch_dtype) for k, v in attn_processors.items()} - - # device - attn_processors = {k: v.to(self.device) for k, v in attn_processors.items()} + # set correct dtype & device + attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()} # set layers self.set_attn_processor(attn_processors) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index aa1225f527d6..89b1757a3d7d 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ 
b/src/diffusers/models/unet_2d_condition.py @@ -274,7 +274,7 @@ def attn_processors(self) -> Dict[str, AttnProcessor]: def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]): if hasattr(module, "set_processor"): - processors[name] = module.processor + processors[f"{name}.processor"] = module.processor for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) @@ -300,7 +300,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if not isinstance(processor, dict): module.set_processor(processor) else: - module.set_processor(processor.pop(name)) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) From b7478efd56e544698df4250d5ebf4f9d34743652 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 17 Jan 2023 13:52:47 +0000 Subject: [PATCH 16/26] up --- docs/source/en/api/loaders.mdx | 30 +++++++ examples/dreambooth/train_dreambooth_lora.py | 2 +- src/diffusers/loaders.py | 83 +++++++++++++++++-- src/diffusers/models/cross_attention.py | 12 +-- src/diffusers/models/unet_2d_condition.py | 4 +- tests/models/test_models_unet_2d_condition.py | 20 ++--- 6 files changed, 123 insertions(+), 28 deletions(-) create mode 100644 docs/source/en/api/loaders.mdx diff --git a/docs/source/en/api/loaders.mdx b/docs/source/en/api/loaders.mdx new file mode 100644 index 000000000000..b8fb4c4db495 --- /dev/null +++ b/docs/source/en/api/loaders.mdx @@ -0,0 +1,30 @@ + + +# Loaders + +There are many weights to train adapter neural networks for diffusion models, such as +- [Textual Inversion](./training/text_inversion.mdx) +- [LoRA](https://github.com/cloneofsimo/lora) +- [Hypernetworks](https://arxiv.org/abs/1609.09106) + +Such adapter neural networks often only conists of a fraction of the number of weights compared +to the pretrained model and as such are very portable. The Diffusers library offers an easy-to-use +API to load such adapter neural networks via the [`loaders.py` module](https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders.py). + +**Note**: This module is still highly experimental and prone to future changes. + +## LoaderMixins + +### UNet2DConditionLoadersMixin + +[[autodoc]] loaders.UNet2DConditionLoadersMixin diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index ce26df338a25..1a0666d9378c 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -554,7 +554,7 @@ def main(args): else: repo_name = args.hub_model_id - repo_name = create_repo(repo_name, exists_ok=True) + repo_name = create_repo(repo_name, exist_ok=True) repo = Repository(args.output_dir, clone_from=repo_name) with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 3690e525ba73..e60f2abcfd10 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -58,8 +58,75 @@ def map_from(module, state_dict, *args, **kwargs): self._register_load_state_dict_pre_hook(map_from, with_module=True) -class AttnProcsLoader: +class UNet2DConditionLoadersMixin: def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + r""" + Load pretrained attention processor layers into `UNet2DConditionModel`. 
Attention processor layers have to be defined in [cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py) and be a `torch.nn.Module` class. + + + + This function is experimental and might change in the future + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. 
+ + + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) @@ -105,9 +172,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict lora_grouped_dict[attn_processor_key][sub_key] = value for key, value_dict in lora_grouped_dict.items(): - rank = value_dict["to_k_lora.lora_down.weight"].shape[0] - cross_attention_dim = value_dict["to_k_lora.lora_down.weight"].shape[1] - hidden_size = value_dict["to_k_lora.lora_up.weight"].shape[0] + rank = value_dict["to_k_lora.down.weight"].shape[0] + cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] + hidden_size = value_dict["to_k_lora.up.weight"].shape[0] attn_processors[key] = LoRACrossAttnProcessor( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank @@ -130,9 +197,9 @@ def save_attn_procs( weights_name: str = ATTN_WEIGHT_NAME, save_function: Callable = None, ): - """ - Save a model and its configuration file to a directory, so that it can be re-loaded using the - `[`~models.ModelMixin.from_pretrained`]` class method. + r""" + Save an attention procesor to a directory, so that it can be re-loaded using the + `[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`]` method. Arguments: save_directory (`str` or `os.PathLike`): @@ -145,8 +212,6 @@ def save_attn_procs( The function to use to save the state dictionary. Useful on distributed training like TPUs when one need to replace `torch.save` by another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `False`): - Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). 
""" if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 2a95865a2a1c..cc57762188a9 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -253,16 +253,16 @@ def __init__(self, in_features, out_features, rank=4): if rank > min(in_features, out_features): raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") - self.lora_down = nn.Linear(in_features, rank, bias=False) - self.lora_up = nn.Linear(rank, out_features, bias=False) + self.down = nn.Linear(in_features, rank, bias=False) + self.up = nn.Linear(rank, out_features, bias=False) self.scale = 1.0 - nn.init.normal_(self.lora_down.weight, std=1 / rank) - nn.init.zeros_(self.lora_up.weight) + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) def forward(self, hidden_states): - down_hidden_states = self.lora_down(hidden_states) - up_hidden_states = self.lora_up(down_hidden_states) + down_hidden_states = self.down(hidden_states) + up_hidden_states = self.up(down_hidden_states) return up_hidden_states diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 89b1757a3d7d..e4d6e123ebff 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -19,7 +19,7 @@ import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import AttnProcsLoader +from ..loaders import UNet2DConditionLoadersMixin from ..utils import BaseOutput, logging from .cross_attention import AttnProcessor from .embeddings import TimestepEmbedding, Timesteps @@ -50,7 +50,7 @@ class UNet2DConditionOutput(BaseOutput): sample: torch.FloatTensor -class UNet2DConditionModel(ModelMixin, ConfigMixin, AttnProcsLoader): +class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep and returns sample shaped output. 
diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 43220cdf73cb..f8f29e9f232d 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -284,7 +284,7 @@ def test_lora_processors(self): lora_attn_procs = {} for name in model.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1") else model.config.cross_attention_dim + cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = model.config.block_out_channels[-1] elif name.startswith("up_blocks"): @@ -300,10 +300,10 @@ def test_lora_processors(self): # add 1 to weights to mock trained weights with torch.no_grad(): - lora_attn_procs[name].to_q_lora.lora_up.weight += 1 - lora_attn_procs[name].to_k_lora.lora_up.weight += 1 - lora_attn_procs[name].to_v_lora.lora_up.weight += 1 - lora_attn_procs[name].to_out_lora.lora_up.weight += 1 + lora_attn_procs[name].to_q_lora.up.weight += 1 + lora_attn_procs[name].to_k_lora.up.weight += 1 + lora_attn_procs[name].to_v_lora.up.weight += 1 + lora_attn_procs[name].to_out_lora.up.weight += 1 # make sure we can set a list of attention processors model.set_attn_processor(lora_attn_procs) @@ -338,7 +338,7 @@ def test_lora_save_load(self): lora_attn_procs = {} for name in model.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1") else model.config.cross_attention_dim + cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = model.config.block_out_channels[-1] elif name.startswith("up_blocks"): @@ -355,10 +355,10 @@ def test_lora_save_load(self): # add 1 to weights to mock trained weights with torch.no_grad(): - lora_attn_procs[name].to_q_lora.lora_up.weight += 1 - lora_attn_procs[name].to_k_lora.lora_up.weight += 1 - lora_attn_procs[name].to_v_lora.lora_up.weight += 1 - lora_attn_procs[name].to_out_lora.lora_up.weight += 1 + lora_attn_procs[name].to_q_lora.up.weight += 1 + lora_attn_procs[name].to_k_lora.up.weight += 1 + lora_attn_procs[name].to_v_lora.up.weight += 1 + lora_attn_procs[name].to_out_lora.up.weight += 1 model.set_attn_processor(lora_attn_procs) From 1530b76088dd4b5e0d96b268df3040b5bd6b7a57 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 17 Jan 2023 15:45:30 +0000 Subject: [PATCH 17/26] up --- examples/dreambooth/README.md | 3 +++ src/diffusers/models/unet_2d_condition.py | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index 621dac8a681c..17f692aaf728 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -245,6 +245,9 @@ In a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-de - Rank-decomposition matrices have significantly fewer parameters than orginal model which means that trained LoRA weights are easily portable. - LoRA attention layers allow to control to which extend the model is adapted torwards new training images via a `scale` parameter. +[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in +the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository. + ### Training Let's get started with a simple example. We will re-use the dog example of the [previous section](#dog-toy-example). 
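To make the rank-decomposition idea described in the README hunk above concrete, here is a rough, self-contained sketch of the mechanism. It is illustrative only and not the diffusers implementation (which lives in `cross_attention.py`); the class and variable names are made up for this example:

```python
# Illustrative sketch: a frozen linear layer plus a low-rank update whose
# contribution is controlled by `scale`. Only `down` and `up` are trained.
import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    def __init__(self, in_features: int, out_features: int, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)  # stands in for the pretrained weight
        self.base.requires_grad_(False)
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.zeros_(self.up.weight)  # zero init: training starts from the pretrained behaviour
        self.scale = scale

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.base(hidden_states) + self.scale * self.up(self.down(hidden_states))


sample = torch.randn(2, 16)
layer = LoRALinearSketch(16, 32, rank=4)
print(layer(sample).shape)  # torch.Size([2, 32])
```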
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index e4d6e123ebff..e3197fe5d4c7 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -269,6 +269,10 @@ def __init__( @property def attn_processors(self) -> Dict[str, AttnProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. + """ # set recursively processors = {} @@ -287,6 +291,12 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]): + r""" + Parameters: + `processor (`dict` of `AttnProcessor` or `AttnProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor of **all** `CrossAttention` layers. + In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors. + + """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: From 17850dee3b3c5e9d09727dc64ff351759b70e32f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 17 Jan 2023 15:47:27 +0000 Subject: [PATCH 18/26] change year --- CONTRIBUTING.md | 2 +- docs/source/en/api/configuration.mdx | 2 +- docs/source/en/api/diffusion_pipeline.mdx | 2 +- docs/source/en/api/experimental/rl.mdx | 2 +- docs/source/en/api/loaders.mdx | 2 +- docs/source/en/api/logging.mdx | 2 +- docs/source/en/api/models.mdx | 2 +- docs/source/en/api/outputs.mdx | 2 +- .../source/en/api/pipelines/alt_diffusion.mdx | 2 +- .../en/api/pipelines/audio_diffusion.mdx | 2 +- .../en/api/pipelines/cycle_diffusion.mdx | 2 +- .../en/api/pipelines/dance_diffusion.mdx | 2 +- docs/source/en/api/pipelines/ddim.mdx | 2 +- docs/source/en/api/pipelines/ddpm.mdx | 2 +- .../en/api/pipelines/latent_diffusion.mdx | 2 +- .../api/pipelines/latent_diffusion_uncond.mdx | 2 +- docs/source/en/api/pipelines/overview.mdx | 2 +- .../en/api/pipelines/paint_by_example.mdx | 2 +- docs/source/en/api/pipelines/pndm.mdx | 2 +- docs/source/en/api/pipelines/repaint.mdx | 2 +- docs/source/en/api/pipelines/score_sde_ve.mdx | 2 +- .../pipelines/stable_diffusion/depth2img.mdx | 2 +- .../stable_diffusion/image_variation.mdx | 2 +- .../pipelines/stable_diffusion/img2img.mdx | 2 +- .../pipelines/stable_diffusion/inpaint.mdx | 2 +- .../pipelines/stable_diffusion/overview.mdx | 2 +- .../pipelines/stable_diffusion/text2img.mdx | 2 +- .../pipelines/stable_diffusion/upscale.mdx | 2 +- .../en/api/pipelines/stable_diffusion_2.mdx | 2 +- .../api/pipelines/stable_diffusion_safe.mdx | 2 +- .../en/api/pipelines/stochastic_karras_ve.mdx | 2 +- docs/source/en/api/pipelines/unclip.mdx | 2 +- .../en/api/pipelines/versatile_diffusion.mdx | 2 +- docs/source/en/api/pipelines/vq_diffusion.mdx | 2 +- docs/source/en/api/schedulers/ddim.mdx | 2 +- docs/source/en/api/schedulers/ddpm.mdx | 2 +- docs/source/en/api/schedulers/deis.mdx | 2 +- .../source/en/api/schedulers/dpm_discrete.mdx | 2 +- .../api/schedulers/dpm_discrete_ancestral.mdx | 2 +- docs/source/en/api/schedulers/euler.mdx | 2 +- .../en/api/schedulers/euler_ancestral.mdx | 2 +- docs/source/en/api/schedulers/heun.mdx | 2 +- docs/source/en/api/schedulers/ipndm.mdx | 2 +- .../source/en/api/schedulers/lms_discrete.mdx | 2 +- 
.../api/schedulers/multistep_dpm_solver.mdx | 2 +- docs/source/en/api/schedulers/overview.mdx | 2 +- docs/source/en/api/schedulers/pndm.mdx | 2 +- docs/source/en/api/schedulers/repaint.mdx | 2 +- .../source/en/api/schedulers/score_sde_ve.mdx | 2 +- .../source/en/api/schedulers/score_sde_vp.mdx | 2 +- .../api/schedulers/singlestep_dpm_solver.mdx | 2 +- .../api/schedulers/stochastic_karras_ve.mdx | 2 +- .../source/en/api/schedulers/vq_diffusion.mdx | 2 +- docs/source/en/conceptual/contribution.mdx | 2 +- docs/source/en/conceptual/philosophy.mdx | 2 +- docs/source/en/imgs/access_request.png | Bin 104814 -> 104814 bytes docs/source/en/index.mdx | 2 +- docs/source/en/installation.mdx | 2 +- docs/source/en/optimization/fp16.mdx | 2 +- docs/source/en/optimization/habana.mdx | 2 +- docs/source/en/optimization/mps.mdx | 2 +- docs/source/en/optimization/onnx.mdx | 2 +- docs/source/en/optimization/open_vino.mdx | 2 +- docs/source/en/optimization/xformers.mdx | 2 +- docs/source/en/quicktour.mdx | 2 +- docs/source/en/stable_diffusion.mdx | 4 ++-- docs/source/en/training/dreambooth.mdx | 2 +- docs/source/en/training/overview.mdx | 2 +- docs/source/en/training/text2image.mdx | 2 +- docs/source/en/training/text_inversion.mdx | 2 +- .../en/training/unconditional_training.mdx | 2 +- docs/source/en/using-diffusers/audio.mdx | 2 +- .../conditional_image_generation.mdx | 2 +- .../en/using-diffusers/configuration.mdx | 2 +- .../using-diffusers/contribute_pipeline.mdx | 2 +- .../custom_pipeline_examples.mdx | 2 +- .../custom_pipeline_overview.mdx | 2 +- docs/source/en/using-diffusers/depth2img.mdx | 2 +- docs/source/en/using-diffusers/img2img.mdx | 2 +- docs/source/en/using-diffusers/inpaint.mdx | 2 +- docs/source/en/using-diffusers/loading.mdx | 2 +- .../en/using-diffusers/other-modalities.mdx | 2 +- .../en/using-diffusers/reusing_seeds.mdx | 2 +- docs/source/en/using-diffusers/rl.mdx | 2 +- docs/source/en/using-diffusers/schedulers.mdx | 2 +- .../unconditional_image_generation.mdx | 2 +- docs/source/ko/in_translation.mdx | 2 +- docs/source/ko/index.mdx | 2 +- docs/source/ko/installation.mdx | 2 +- docs/source/ko/quicktour.mdx | 2 +- examples/README.md | 2 +- .../community/composable_stable_diffusion.py | 2 +- examples/community/sd_text2img_k_diffusion.py | 2 +- examples/community/tiled_upscaling.py | 2 +- examples/conftest.py | 2 +- examples/test_examples.py | 2 +- .../change_naming_configs_and_checkpoints.py | 2 +- ...rt_ldm_original_checkpoint_to_diffusers.py | 2 +- ...ncsnpp_original_checkpoint_to_diffusers.py | 2 +- ..._original_stable_diffusion_to_diffusers.py | 2 +- ...ert_stable_diffusion_checkpoint_to_onnx.py | 2 +- ...onvert_versatile_diffusion_to_diffusers.py | 2 +- setup.py | 2 +- src/diffusers/commands/__init__.py | 2 +- src/diffusers/commands/diffusers_cli.py | 2 +- src/diffusers/commands/env.py | 2 +- src/diffusers/configuration_utils.py | 2 +- src/diffusers/dependency_versions_check.py | 2 +- .../experimental/rl/value_guided_sampling.py | 2 +- src/diffusers/loaders.py | 2 +- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/attention.py | 2 +- src/diffusers/models/attention_flax.py | 2 +- src/diffusers/models/autoencoder_kl.py | 2 +- src/diffusers/models/cross_attention.py | 2 +- src/diffusers/models/dual_transformer_2d.py | 2 +- src/diffusers/models/embeddings.py | 2 +- src/diffusers/models/embeddings_flax.py | 2 +- .../models/modeling_flax_pytorch_utils.py | 2 +- src/diffusers/models/modeling_flax_utils.py | 2 +- .../models/modeling_pytorch_flax_utils.py | 2 +- 
src/diffusers/models/modeling_utils.py | 2 +- src/diffusers/models/resnet_flax.py | 2 +- src/diffusers/models/transformer_2d.py | 2 +- src/diffusers/models/unet_1d.py | 2 +- src/diffusers/models/unet_1d_blocks.py | 2 +- src/diffusers/models/unet_2d.py | 2 +- src/diffusers/models/unet_2d_blocks.py | 2 +- src/diffusers/models/unet_2d_blocks_flax.py | 2 +- src/diffusers/models/unet_2d_condition.py | 2 +- .../models/unet_2d_condition_flax.py | 2 +- src/diffusers/models/vae.py | 2 +- src/diffusers/models/vae_flax.py | 2 +- src/diffusers/models/vq_model.py | 2 +- src/diffusers/optimization.py | 2 +- src/diffusers/pipeline_utils.py | 2 +- .../alt_diffusion/pipeline_alt_diffusion.py | 2 +- .../pipeline_alt_diffusion_img2img.py | 2 +- .../pipelines/audio_diffusion/mel.py | 2 +- .../pipeline_audio_diffusion.py | 2 +- .../pipeline_dance_diffusion.py | 2 +- src/diffusers/pipelines/ddim/pipeline_ddim.py | 2 +- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 2 +- .../pipeline_latent_diffusion.py | 2 +- .../pipeline_latent_diffusion_uncond.py | 2 +- src/diffusers/pipelines/onnx_utils.py | 2 +- .../paint_by_example/image_encoder.py | 2 +- .../pipeline_paint_by_example.py | 2 +- .../pipelines/pipeline_flax_utils.py | 2 +- src/diffusers/pipelines/pipeline_utils.py | 2 +- src/diffusers/pipelines/pndm/pipeline_pndm.py | 2 +- .../pipelines/repaint/pipeline_repaint.py | 2 +- .../score_sde_ve/pipeline_score_sde_ve.py | 2 +- .../pipeline_cycle_diffusion.py | 2 +- .../pipeline_flax_stable_diffusion.py | 2 +- .../pipeline_flax_stable_diffusion_img2img.py | 2 +- .../pipeline_onnx_stable_diffusion.py | 2 +- .../pipeline_onnx_stable_diffusion_img2img.py | 2 +- .../pipeline_onnx_stable_diffusion_inpaint.py | 2 +- .../pipeline_stable_diffusion.py | 2 +- .../pipeline_stable_diffusion_depth2img.py | 2 +- ...peline_stable_diffusion_image_variation.py | 2 +- .../pipeline_stable_diffusion_img2img.py | 2 +- .../pipeline_stable_diffusion_inpaint.py | 2 +- ...ipeline_stable_diffusion_inpaint_legacy.py | 2 +- .../pipeline_stable_diffusion_k_diffusion.py | 2 +- .../pipeline_stable_diffusion_upscale.py | 2 +- .../stable_diffusion/safety_checker.py | 2 +- .../stable_diffusion/safety_checker_flax.py | 2 +- .../stable_diffusion_safe/safety_checker.py | 2 +- .../pipeline_stochastic_karras_ve.py | 2 +- .../pipelines/unclip/pipeline_unclip.py | 2 +- .../unclip/pipeline_unclip_image_variation.py | 2 +- src/diffusers/pipelines/unclip/text_proj.py | 2 +- ...ipeline_versatile_diffusion_dual_guided.py | 2 +- ...ine_versatile_diffusion_image_variation.py | 2 +- ...eline_versatile_diffusion_text_to_image.py | 2 +- .../vq_diffusion/pipeline_vq_diffusion.py | 2 +- src/diffusers/schedulers/__init__.py | 2 +- src/diffusers/schedulers/scheduling_ddim.py | 2 +- .../schedulers/scheduling_ddim_flax.py | 2 +- src/diffusers/schedulers/scheduling_ddpm.py | 2 +- .../schedulers/scheduling_ddpm_flax.py | 2 +- .../schedulers/scheduling_deis_multistep.py | 2 +- .../scheduling_dpmsolver_multistep.py | 2 +- .../scheduling_dpmsolver_multistep_flax.py | 2 +- .../scheduling_dpmsolver_singlestep.py | 2 +- .../scheduling_euler_ancestral_discrete.py | 2 +- .../schedulers/scheduling_euler_discrete.py | 2 +- .../schedulers/scheduling_heun_discrete.py | 2 +- src/diffusers/schedulers/scheduling_ipndm.py | 2 +- .../scheduling_k_dpm_2_ancestral_discrete.py | 2 +- .../schedulers/scheduling_k_dpm_2_discrete.py | 2 +- .../schedulers/scheduling_karras_ve.py | 2 +- .../schedulers/scheduling_karras_ve_flax.py | 2 +- .../schedulers/scheduling_lms_discrete.py | 2 +- 
.../scheduling_lms_discrete_flax.py | 2 +- src/diffusers/schedulers/scheduling_pndm.py | 2 +- .../schedulers/scheduling_pndm_flax.py | 2 +- .../schedulers/scheduling_repaint.py | 2 +- src/diffusers/schedulers/scheduling_sde_ve.py | 2 +- .../schedulers/scheduling_sde_ve_flax.py | 2 +- src/diffusers/schedulers/scheduling_sde_vp.py | 2 +- src/diffusers/schedulers/scheduling_unclip.py | 2 +- src/diffusers/schedulers/scheduling_utils.py | 2 +- .../schedulers/scheduling_utils_flax.py | 2 +- .../schedulers/scheduling_vq_diffusion.py | 2 +- src/diffusers/utils/__init__.py | 2 +- src/diffusers/utils/constants.py | 2 +- src/diffusers/utils/doc_utils.py | 2 +- src/diffusers/utils/dynamic_modules_utils.py | 2 +- src/diffusers/utils/hub_utils.py | 2 +- src/diffusers/utils/import_utils.py | 2 +- src/diffusers/utils/logging.py | 2 +- src/diffusers/utils/outputs.py | 2 +- src/diffusers/utils/torch_utils.py | 2 +- tests/conftest.py | 2 +- tests/fixtures/custom_pipeline/pipeline.py | 2 +- tests/fixtures/custom_pipeline/what_ever.py | 2 +- tests/models/test_models_unet_1d.py | 2 +- tests/models/test_models_unet_2d.py | 2 +- tests/models/test_models_unet_2d_condition.py | 2 +- tests/models/test_models_vae.py | 2 +- tests/models/test_models_vq.py | 2 +- .../altdiffusion/test_alt_diffusion.py | 2 +- .../test_alt_diffusion_img2img.py | 2 +- .../audio_diffusion/test_audio_diffusion.py | 2 +- .../dance_diffusion/test_dance_diffusion.py | 2 +- tests/pipelines/ddim/test_ddim.py | 2 +- tests/pipelines/ddpm/test_ddpm.py | 2 +- tests/pipelines/karras_ve/test_karras_ve.py | 2 +- .../latent_diffusion/test_latent_diffusion.py | 2 +- .../test_latent_diffusion_superresolution.py | 2 +- .../test_latent_diffusion_uncond.py | 2 +- .../paint_by_example/test_paint_by_example.py | 2 +- tests/pipelines/pndm/test_pndm.py | 2 +- tests/pipelines/repaint/test_repaint.py | 2 +- .../score_sde_ve/test_score_sde_ve.py | 2 +- .../stable_diffusion/test_cycle_diffusion.py | 2 +- .../test_onnx_stable_diffusion.py | 2 +- .../test_onnx_stable_diffusion_img2img.py | 2 +- .../test_onnx_stable_diffusion_inpaint.py | 2 +- ...st_onnx_stable_diffusion_inpaint_legacy.py | 2 +- .../stable_diffusion/test_stable_diffusion.py | 2 +- .../test_stable_diffusion_image_variation.py | 2 +- .../test_stable_diffusion_img2img.py | 2 +- .../test_stable_diffusion_inpaint.py | 2 +- .../test_stable_diffusion_inpaint_legacy.py | 2 +- .../test_stable_diffusion_k_diffusion.py | 2 +- .../test_stable_diffusion.py | 2 +- .../test_stable_diffusion_depth.py | 2 +- .../test_stable_diffusion_flax.py | 2 +- .../test_stable_diffusion_inpaint.py | 2 +- .../test_stable_diffusion_upscale.py | 2 +- .../test_stable_diffusion_v_pred.py | 2 +- .../test_safe_diffusion.py | 2 +- tests/pipelines/unclip/test_unclip.py | 2 +- .../unclip/test_unclip_image_variation.py | 2 +- .../test_versatile_diffusion_dual_guided.py | 2 +- ...est_versatile_diffusion_image_variation.py | 2 +- .../test_versatile_diffusion_mega.py | 2 +- .../test_versatile_diffusion_text_to_image.py | 2 +- .../vq_diffusion/test_vq_diffusion.py | 2 +- tests/repo_utils/test_check_copies.py | 2 +- tests/repo_utils/test_check_dummies.py | 2 +- tests/test_config.py | 2 +- tests/test_layers_utils.py | 2 +- tests/test_modeling_common.py | 2 +- tests/test_pipelines.py | 2 +- tests/test_pipelines_flax.py | 2 +- tests/test_scheduler.py | 2 +- tests/test_scheduler_flax.py | 2 +- tests/test_training.py | 2 +- tests/test_unet_2d_blocks.py | 2 +- tests/test_unet_blocks_common.py | 2 +- tests/test_utils.py | 2 +- 
utils/check_config_docstrings.py | 2 +- utils/check_copies.py | 2 +- utils/check_doc_toc.py | 2 +- utils/check_dummies.py | 2 +- utils/check_inits.py | 2 +- utils/check_repo.py | 2 +- utils/check_table.py | 2 +- utils/custom_init_isort.py | 2 +- utils/get_modified_files.py | 2 +- utils/print_env.py | 2 +- utils/stale.py | 2 +- 287 files changed, 287 insertions(+), 287 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6f725ae87946..3b051427a231 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,5 +1,5 @@