diff --git a/examples/community/README.md b/examples/community/README.md index 065b46f5410c..17cd34a5182d 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -1601,7 +1601,7 @@ pipe_images = mixing_pipeline( ![image_mixing_result](https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/boromir_gigachad.png) -### Stable Diffusion Mixture +### Stable Diffusion Mixture Tiling This pipeline uses the Mixture. Refer to the [Mixture](https://arxiv.org/abs/2302.02412) paper for more details. @@ -1672,4 +1672,38 @@ mask_image = Image.open(BytesIO(response.content)).convert("RGB") prompt = "a mecha robot sitting on a bench" image = pipe(prompt, image=input_image, mask_image=mask_image, strength=0.75,).images[0] image.save('tensorrt_inpaint_mecha_robot.png') -``` \ No newline at end of file +``` + +### Stable Diffusion Mixture Canvas + +This pipeline uses the Mixture. Refer to the [Mixture](https://arxiv.org/abs/2302.02412) paper for more details. + +```python +from PIL import Image +from diffusers import LMSDiscreteScheduler, DiffusionPipeline +from diffusers.pipelines.pipeline_utils import Image2ImageRegion, Text2ImageRegion, preprocess_image + + +# Load and preprocess guide image +iic_image = preprocess_image(Image.open("input_image.png").convert("RGB")) + +# Creater scheduler and model (similar to StableDiffusionPipeline) +scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) +pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler).to("cuda:0", custom_pipeline="mixture_canvas") +pipeline.to("cuda") + +# Mixture of Diffusers generation +output = pipeline( + canvas_height=800, + canvas_width=352, + regions=[ + Text2ImageRegion(0, 800, 0, 352, guidance_scale=8, + prompt=f"best quality, masterpiece, WLOP, sakimichan, art contest winner on pixiv, 8K, intricate details, wet effects, rain drops, ethereal, mysterious, futuristic, UHD, HDR, cinematic lighting, in a beautiful forest, rainy day, award winning, trending on artstation, beautiful confident cheerful young woman, wearing a futuristic sleeveless dress, ultra beautiful detailed eyes, hyper-detailed face, complex, perfect, model,  textured, chiaroscuro, professional make-up, realistic, figure in frame, "), + Image2ImageRegion(352-800, 352, 0, 352, reference_image=iic_image, strength=1.0), + ], + num_inference_steps=100, + seed=5525475061, +)["images"][0] +``` +![Input_Image](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/input_image.png) +![mixture_canvas_results](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/canvas.png) diff --git a/examples/community/mixture.py b/examples/community/mixture.py deleted file mode 100644 index 845ad76b6a2e..000000000000 --- a/examples/community/mixture.py +++ /dev/null @@ -1,401 +0,0 @@ -import inspect -from copy import deepcopy -from enum import Enum -from typing import List, Optional, Tuple, Union - -import torch -from ligo.segments import segment -from tqdm.auto import tqdm -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker -from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from diffusers.utils import logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> from diffusers import LMSDiscreteScheduler - >>> from mixdiff import StableDiffusionTilingPipeline - - >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) - >>> pipeline = StableDiffusionTilingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler) - >>> pipeline.to("cuda:0") - - >>> image = pipeline( - >>> prompt=[[ - >>> "A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece", - >>> "A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece", - >>> "An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece" - >>> ]], - >>> tile_height=640, - >>> tile_width=640, - >>> tile_row_overlap=0, - >>> tile_col_overlap=256, - >>> guidance_scale=8, - >>> seed=7178915308, - >>> num_inference_steps=50, - >>> )["images"][0] - ``` -""" - - -def _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap): - """Given a tile row and column numbers returns the range of pixels affected by that tiles in the overall image - - Returns a tuple with: - - Starting coordinates of rows in pixel space - - Ending coordinates of rows in pixel space - - Starting coordinates of columns in pixel space - - Ending coordinates of columns in pixel space - """ - px_row_init = 0 if tile_row == 0 else tile_row * (tile_height - tile_row_overlap) - px_row_end = px_row_init + tile_height - px_col_init = 0 if tile_col == 0 else tile_col * (tile_width - tile_col_overlap) - px_col_end = px_col_init + tile_width - return px_row_init, px_row_end, px_col_init, px_col_end - - -def _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end): - """Translates coordinates in pixel space to coordinates in latent space""" - return px_row_init // 8, px_row_end // 8, px_col_init // 8, px_col_end // 8 - - -def _tile2latent_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap): - """Given a tile row and column numbers returns the range of latents affected by that tiles in the overall image - - Returns a tuple with: - - Starting coordinates of rows in latent space - - Ending coordinates of rows in latent space - - Starting coordinates of columns in latent space - - Ending coordinates of columns in latent space - """ - px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices( - tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap - ) - return _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end) - - -def _tile2latent_exclusive_indices( - tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, rows, columns -): - """Given a tile row and column numbers returns the range of latents affected only by that tile in the overall image - - Returns a tuple with: - - Starting coordinates of rows in latent space - - Ending coordinates of rows in latent space - - Starting coordinates of columns in latent space - - Ending coordinates of columns in latent space - """ - row_init, row_end, col_init, col_end = _tile2latent_indices( - tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap - ) - row_segment = segment(row_init, row_end) - col_segment = segment(col_init, col_end) - # Iterate over the rest of tiles, clipping the region for the current tile - for row in range(rows): - for column in range(columns): - if row != tile_row and column != tile_col: - clip_row_init, clip_row_end, clip_col_init, clip_col_end = _tile2latent_indices( - row, column, tile_width, tile_height, tile_row_overlap, tile_col_overlap - ) - row_segment = row_segment - segment(clip_row_init, clip_row_end) - col_segment = col_segment - segment(clip_col_init, clip_col_end) - # return row_init, row_end, col_init, col_end - return row_segment[0], row_segment[1], col_segment[0], col_segment[1] - - -class StableDiffusionExtrasMixin: - """Mixin providing additional convenience method to Stable Diffusion pipelines""" - - def decode_latents(self, latents, cpu_vae=False): - """Decodes a given array of latents into pixel space""" - # scale and decode the image latents with vae - if cpu_vae: - lat = deepcopy(latents).cpu() - vae = deepcopy(self.vae).cpu() - else: - lat = latents - vae = self.vae - - lat = 1 / 0.18215 * lat - image = vae.decode(lat).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - - return self.numpy_to_pil(image) - - -class StableDiffusionTilingPipeline(DiffusionPipeline, StableDiffusionExtrasMixin): - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler], - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - ): - super().__init__() - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - - class SeedTilesMode(Enum): - """Modes in which the latents of a particular tile can be re-seeded""" - - FULL = "full" - EXCLUSIVE = "exclusive" - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[List[str]]], - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - eta: Optional[float] = 0.0, - seed: Optional[int] = None, - tile_height: Optional[int] = 512, - tile_width: Optional[int] = 512, - tile_row_overlap: Optional[int] = 256, - tile_col_overlap: Optional[int] = 256, - guidance_scale_tiles: Optional[List[List[float]]] = None, - seed_tiles: Optional[List[List[int]]] = None, - seed_tiles_mode: Optional[Union[str, List[List[str]]]] = "full", - seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None, - cpu_vae: Optional[bool] = False, - ): - r""" - Function to run the diffusion pipeline with tiling support. - - Args: - prompt: either a single string (no tiling) or a list of lists with all the prompts to use (one list for each row of tiles). This will also define the tiling structure. - num_inference_steps: number of diffusions steps. - guidance_scale: classifier-free guidance. - seed: general random seed to initialize latents. - tile_height: height in pixels of each grid tile. - tile_width: width in pixels of each grid tile. - tile_row_overlap: number of overlap pixels between tiles in consecutive rows. - tile_col_overlap: number of overlap pixels between tiles in consecutive columns. - guidance_scale_tiles: specific weights for classifier-free guidance in each tile. - guidance_scale_tiles: specific weights for classifier-free guidance in each tile. If None, the value provided in guidance_scale will be used. - seed_tiles: specific seeds for the initialization latents in each tile. These will override the latents generated for the whole canvas using the standard seed parameter. - seed_tiles_mode: either "full" "exclusive". If "full", all the latents affected by the tile be overriden. If "exclusive", only the latents that are affected exclusively by this tile (and no other tiles) will be overrriden. - seed_reroll_regions: a list of tuples in the form (start row, end row, start column, end column, seed) defining regions in pixel space for which the latents will be overriden using the given seed. Takes priority over seed_tiles. - cpu_vae: the decoder from latent space to pixel space can require too mucho GPU RAM for large images. If you find out of memory errors at the end of the generation process, try setting this parameter to True to run the decoder in CPU. Slower, but should run without memory issues. - - Examples: - - Returns: - A PIL image with the generated image. - - """ - if not isinstance(prompt, list) or not all(isinstance(row, list) for row in prompt): - raise ValueError(f"`prompt` has to be a list of lists but is {type(prompt)}") - grid_rows = len(prompt) - grid_cols = len(prompt[0]) - if not all(len(row) == grid_cols for row in prompt): - raise ValueError("All prompt rows must have the same number of prompt columns") - if not isinstance(seed_tiles_mode, str) and ( - not isinstance(seed_tiles_mode, list) or not all(isinstance(row, list) for row in seed_tiles_mode) - ): - raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}") - if isinstance(seed_tiles_mode, str): - seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt] - modes = [mode.value for mode in self.SeedTilesMode] - if any(mode not in modes for row in seed_tiles_mode for mode in row): - raise ValueError(f"Seed tiles mode must be one of {modes}") - if seed_reroll_regions is None: - seed_reroll_regions = [] - batch_size = 1 - - # create original noisy latents using the timesteps - height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap) - width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap) - latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8) - generator = torch.Generator("cuda").manual_seed(seed) - latents = torch.randn(latents_shape, generator=generator, device=self.device) - - # overwrite latents for specific tiles if provided - if seed_tiles is not None: - for row in range(grid_rows): - for col in range(grid_cols): - if (seed_tile := seed_tiles[row][col]) is not None: - mode = seed_tiles_mode[row][col] - if mode == self.SeedTilesMode.FULL.value: - row_init, row_end, col_init, col_end = _tile2latent_indices( - row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap - ) - else: - row_init, row_end, col_init, col_end = _tile2latent_exclusive_indices( - row, - col, - tile_width, - tile_height, - tile_row_overlap, - tile_col_overlap, - grid_rows, - grid_cols, - ) - tile_generator = torch.Generator("cuda").manual_seed(seed_tile) - tile_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init) - latents[:, :, row_init:row_end, col_init:col_end] = torch.randn( - tile_shape, generator=tile_generator, device=self.device - ) - - # overwrite again for seed reroll regions - for row_init, row_end, col_init, col_end, seed_reroll in seed_reroll_regions: - row_init, row_end, col_init, col_end = _pixel2latent_indices( - row_init, row_end, col_init, col_end - ) # to latent space coordinates - reroll_generator = torch.Generator("cuda").manual_seed(seed_reroll) - region_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init) - latents[:, :, row_init:row_end, col_init:col_end] = torch.randn( - region_shape, generator=reroll_generator, device=self.device - ) - - # Prepare scheduler - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - if accepts_offset: - extra_set_kwargs["offset"] = 1 - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) - # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents * self.scheduler.sigmas[0] - - # get prompts text embeddings - text_input = [ - [ - self.tokenizer( - col, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - for col in row - ] - for row in prompt - ] - text_embeddings = [[self.text_encoder(col.input_ids.to(self.device))[0] for col in row] for row in text_input] - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 # TODO: also active if any tile has guidance scale - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - for i in range(grid_rows): - for j in range(grid_cols): - max_length = text_input[i][j].input_ids.shape[-1] - uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" - ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings[i][j] = torch.cat([uncond_embeddings, text_embeddings[i][j]]) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # Mask for tile weights strenght - tile_weights = self._gaussian_weights(tile_width, tile_height, batch_size) - - # Diffusion timesteps - for i, t in tqdm(enumerate(self.scheduler.timesteps)): - # Diffuse each tile - noise_preds = [] - for row in range(grid_rows): - noise_preds_row = [] - for col in range(grid_cols): - px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices( - row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap - ) - tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end] - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([tile_latents] * 2) if do_classifier_free_guidance else tile_latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings[row][col])[ - "sample" - ] - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - guidance = ( - guidance_scale - if guidance_scale_tiles is None or guidance_scale_tiles[row][col] is None - else guidance_scale_tiles[row][col] - ) - noise_pred_tile = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond) - noise_preds_row.append(noise_pred_tile) - noise_preds.append(noise_preds_row) - # Stitch noise predictions for all tiles - noise_pred = torch.zeros(latents.shape, device=self.device) - contributors = torch.zeros(latents.shape, device=self.device) - # Add each tile contribution to overall latents - for row in range(grid_rows): - for col in range(grid_cols): - px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices( - row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap - ) - noise_pred[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += ( - noise_preds[row][col] * tile_weights - ) - contributors[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += tile_weights - # Average overlapping areas with more than 1 contributor - noise_pred /= contributors - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents).prev_sample - - # scale and decode the image latents with vae - image = self.decode_latents(latents, cpu_vae) - - return {"images": image} - - def _gaussian_weights(self, tile_width, tile_height, nbatches): - """Generates a gaussian mask of weights for tile contributions""" - import numpy as np - from numpy import exp, pi, sqrt - - latent_width = tile_width // 8 - latent_height = tile_height // 8 - - var = 0.01 - midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1 - x_probs = [ - exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var) - for x in range(latent_width) - ] - midpoint = latent_height / 2 - y_probs = [ - exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var) - for y in range(latent_height) - ] - - weights = np.outer(y_probs, x_probs) - return torch.tile(torch.tensor(weights, device=self.device), (nbatches, self.unet.config.in_channels, 1, 1)) diff --git a/examples/community/mixture_canvas.py b/examples/community/mixture_canvas.py new file mode 100644 index 000000000000..40139d1139ad --- /dev/null +++ b/examples/community/mixture_canvas.py @@ -0,0 +1,503 @@ +import re +from copy import deepcopy +from dataclasses import asdict, dataclass +from enum import Enum +from typing import List, Optional, Union + +import numpy as np +import torch +from numpy import exp, pi, sqrt +from torchvision.transforms.functional import resize +from tqdm.auto import tqdm +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler + + +def preprocess_image(image): + from PIL import Image + + """Preprocess an input image + + Same as + https://github.com/huggingface/diffusers/blob/1138d63b519e37f0ce04e027b9f4a3261d27c628/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L44 + """ + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +@dataclass +class CanvasRegion: + """Class defining a rectangular region in the canvas""" + + row_init: int # Region starting row in pixel space (included) + row_end: int # Region end row in pixel space (not included) + col_init: int # Region starting column in pixel space (included) + col_end: int # Region end column in pixel space (not included) + region_seed: int = None # Seed for random operations in this region + noise_eps: float = 0.0 # Deviation of a zero-mean gaussian noise to be applied over the latents in this region. Useful for slightly "rerolling" latents + + def __post_init__(self): + # Initialize arguments if not specified + if self.region_seed is None: + self.region_seed = np.random.randint(9999999999) + # Check coordinates are non-negative + for coord in [self.row_init, self.row_end, self.col_init, self.col_end]: + if coord < 0: + raise ValueError( + f"A CanvasRegion must be defined with non-negative indices, found ({self.row_init}, {self.row_end}, {self.col_init}, {self.col_end})" + ) + # Check coordinates are divisible by 8, else we end up with nasty rounding error when mapping to latent space + for coord in [self.row_init, self.row_end, self.col_init, self.col_end]: + if coord // 8 != coord / 8: + raise ValueError( + f"A CanvasRegion must be defined with locations divisible by 8, found ({self.row_init}-{self.row_end}, {self.col_init}-{self.col_end})" + ) + # Check noise eps is non-negative + if self.noise_eps < 0: + raise ValueError(f"A CanvasRegion must be defined noises eps non-negative, found {self.noise_eps}") + # Compute coordinates for this region in latent space + self.latent_row_init = self.row_init // 8 + self.latent_row_end = self.row_end // 8 + self.latent_col_init = self.col_init // 8 + self.latent_col_end = self.col_end // 8 + + @property + def width(self): + return self.col_end - self.col_init + + @property + def height(self): + return self.row_end - self.row_init + + def get_region_generator(self, device="cpu"): + """Creates a torch.Generator based on the random seed of this region""" + # Initialize region generator + return torch.Generator(device).manual_seed(self.region_seed) + + @property + def __dict__(self): + return asdict(self) + + +class MaskModes(Enum): + """Modes in which the influence of diffuser is masked""" + + CONSTANT = "constant" + GAUSSIAN = "gaussian" + QUARTIC = "quartic" # See https://en.wikipedia.org/wiki/Kernel_(statistics) + + +@dataclass +class DiffusionRegion(CanvasRegion): + """Abstract class defining a region where some class of diffusion process is acting""" + + pass + + +@dataclass +class Text2ImageRegion(DiffusionRegion): + """Class defining a region where a text guided diffusion process is acting""" + + prompt: str = "" # Text prompt guiding the diffuser in this region + guidance_scale: float = 7.5 # Guidance scale of the diffuser in this region. If None, randomize + mask_type: MaskModes = MaskModes.GAUSSIAN.value # Kind of weight mask applied to this region + mask_weight: float = 1.0 # Global weights multiplier of the mask + tokenized_prompt = None # Tokenized prompt + encoded_prompt = None # Encoded prompt + + def __post_init__(self): + super().__post_init__() + # Mask weight cannot be negative + if self.mask_weight < 0: + raise ValueError( + f"A Text2ImageRegion must be defined with non-negative mask weight, found {self.mask_weight}" + ) + # Mask type must be an actual known mask + if self.mask_type not in [e.value for e in MaskModes]: + raise ValueError( + f"A Text2ImageRegion was defined with mask {self.mask_type}, which is not an accepted mask ({[e.value for e in MaskModes]})" + ) + # Randomize arguments if given as None + if self.guidance_scale is None: + self.guidance_scale = np.random.randint(5, 30) + # Clean prompt + self.prompt = re.sub(" +", " ", self.prompt).replace("\n", " ") + + def tokenize_prompt(self, tokenizer): + """Tokenizes the prompt for this diffusion region using a given tokenizer""" + self.tokenized_prompt = tokenizer( + self.prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + def encode_prompt(self, text_encoder, device): + """Encodes the previously tokenized prompt for this diffusion region using a given encoder""" + assert self.tokenized_prompt is not None, ValueError( + "Prompt in diffusion region must be tokenized before encoding" + ) + self.encoded_prompt = text_encoder(self.tokenized_prompt.input_ids.to(device))[0] + + +@dataclass +class Image2ImageRegion(DiffusionRegion): + """Class defining a region where an image guided diffusion process is acting""" + + reference_image: torch.FloatTensor = None + strength: float = 0.8 # Strength of the image + + def __post_init__(self): + super().__post_init__() + if self.reference_image is None: + raise ValueError("Must provide a reference image when creating an Image2ImageRegion") + if self.strength < 0 or self.strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {self.strength}") + # Rescale image to region shape + self.reference_image = resize(self.reference_image, size=[self.height, self.width]) + + def encode_reference_image(self, encoder, device, generator, cpu_vae=False): + """Encodes the reference image for this Image2Image region into the latent space""" + # Place encoder in CPU or not following the parameter cpu_vae + if cpu_vae: + # Note here we use mean instead of sample, to avoid moving also generator to CPU, which is troublesome + self.reference_latents = encoder.cpu().encode(self.reference_image).latent_dist.mean.to(device) + else: + self.reference_latents = encoder.encode(self.reference_image.to(device)).latent_dist.sample( + generator=generator + ) + self.reference_latents = 0.18215 * self.reference_latents + + @property + def __dict__(self): + # This class requires special casting to dict because of the reference_image tensor. Otherwise it cannot be casted to JSON + + # Get all basic fields from parent class + super_fields = {key: getattr(self, key) for key in DiffusionRegion.__dataclass_fields__.keys()} + # Pack other fields + return {**super_fields, "reference_image": self.reference_image.cpu().tolist(), "strength": self.strength} + + +class RerollModes(Enum): + """Modes in which the reroll regions operate""" + + RESET = "reset" # Completely reset the random noise in the region + EPSILON = "epsilon" # Alter slightly the latents in the region + + +@dataclass +class RerollRegion(CanvasRegion): + """Class defining a rectangular canvas region in which initial latent noise will be rerolled""" + + reroll_mode: RerollModes = RerollModes.RESET.value + + +@dataclass +class MaskWeightsBuilder: + """Auxiliary class to compute a tensor of weights for a given diffusion region""" + + latent_space_dim: int # Size of the U-net latent space + nbatch: int = 1 # Batch size in the U-net + + def compute_mask_weights(self, region: DiffusionRegion) -> torch.tensor: + """Computes a tensor of weights for a given diffusion region""" + MASK_BUILDERS = { + MaskModes.CONSTANT.value: self._constant_weights, + MaskModes.GAUSSIAN.value: self._gaussian_weights, + MaskModes.QUARTIC.value: self._quartic_weights, + } + return MASK_BUILDERS[region.mask_type](region) + + def _constant_weights(self, region: DiffusionRegion) -> torch.tensor: + """Computes a tensor of constant for a given diffusion region""" + latent_width = region.latent_col_end - region.latent_col_init + latent_height = region.latent_row_end - region.latent_row_init + return torch.ones(self.nbatch, self.latent_space_dim, latent_height, latent_width) * region.mask_weight + + def _gaussian_weights(self, region: DiffusionRegion) -> torch.tensor: + """Generates a gaussian mask of weights for tile contributions""" + latent_width = region.latent_col_end - region.latent_col_init + latent_height = region.latent_row_end - region.latent_row_init + + var = 0.01 + midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1 + x_probs = [ + exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var) + for x in range(latent_width) + ] + midpoint = (latent_height - 1) / 2 + y_probs = [ + exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var) + for y in range(latent_height) + ] + + weights = np.outer(y_probs, x_probs) * region.mask_weight + return torch.tile(torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1)) + + def _quartic_weights(self, region: DiffusionRegion) -> torch.tensor: + """Generates a quartic mask of weights for tile contributions + + The quartic kernel has bounded support over the diffusion region, and a smooth decay to the region limits. + """ + quartic_constant = 15.0 / 16.0 + + support = (np.array(range(region.latent_col_init, region.latent_col_end)) - region.latent_col_init) / ( + region.latent_col_end - region.latent_col_init - 1 + ) * 1.99 - (1.99 / 2.0) + x_probs = quartic_constant * np.square(1 - np.square(support)) + support = (np.array(range(region.latent_row_init, region.latent_row_end)) - region.latent_row_init) / ( + region.latent_row_end - region.latent_row_init - 1 + ) * 1.99 - (1.99 / 2.0) + y_probs = quartic_constant * np.square(1 - np.square(support)) + + weights = np.outer(y_probs, x_probs) * region.mask_weight + return torch.tile(torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1)) + + +class StableDiffusionCanvasPipeline(DiffusionPipeline): + """Stable Diffusion pipeline that mixes several diffusers in the same canvas""" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def decode_latents(self, latents, cpu_vae=False): + """Decodes a given array of latents into pixel space""" + # scale and decode the image latents with vae + if cpu_vae: + lat = deepcopy(latents).cpu() + vae = deepcopy(self.vae).cpu() + else: + lat = latents + vae = self.vae + + lat = 1 / 0.18215 * lat + image = vae.decode(lat).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + return self.numpy_to_pil(image) + + def get_latest_timestep_img2img(self, num_inference_steps, strength): + """Finds the latest timesteps where an img2img strength does not impose latents anymore""" + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * (1 - strength)) + offset + init_timestep = min(init_timestep, num_inference_steps) + + t_start = min(max(num_inference_steps - init_timestep + offset, 0), num_inference_steps - 1) + latest_timestep = self.scheduler.timesteps[t_start] + + return latest_timestep + + @torch.no_grad() + def __call__( + self, + canvas_height: int, + canvas_width: int, + regions: List[DiffusionRegion], + num_inference_steps: Optional[int] = 50, + seed: Optional[int] = 12345, + reroll_regions: Optional[List[RerollRegion]] = None, + cpu_vae: Optional[bool] = False, + decode_steps: Optional[bool] = False, + ): + if reroll_regions is None: + reroll_regions = [] + batch_size = 1 + + if decode_steps: + steps_images = [] + + # Prepare scheduler + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + + # Split diffusion regions by their kind + text2image_regions = [region for region in regions if isinstance(region, Text2ImageRegion)] + image2image_regions = [region for region in regions if isinstance(region, Image2ImageRegion)] + + # Prepare text embeddings + for region in text2image_regions: + region.tokenize_prompt(self.tokenizer) + region.encode_prompt(self.text_encoder, self.device) + + # Create original noisy latents using the timesteps + latents_shape = (batch_size, self.unet.config.in_channels, canvas_height // 8, canvas_width // 8) + generator = torch.Generator(self.device).manual_seed(seed) + init_noise = torch.randn(latents_shape, generator=generator, device=self.device) + + # Reset latents in seed reroll regions, if requested + for region in reroll_regions: + if region.reroll_mode == RerollModes.RESET.value: + region_shape = ( + latents_shape[0], + latents_shape[1], + region.latent_row_end - region.latent_row_init, + region.latent_col_end - region.latent_col_init, + ) + init_noise[ + :, + :, + region.latent_row_init : region.latent_row_end, + region.latent_col_init : region.latent_col_end, + ] = torch.randn(region_shape, generator=region.get_region_generator(self.device), device=self.device) + + # Apply epsilon noise to regions: first diffusion regions, then reroll regions + all_eps_rerolls = regions + [r for r in reroll_regions if r.reroll_mode == RerollModes.EPSILON.value] + for region in all_eps_rerolls: + if region.noise_eps > 0: + region_noise = init_noise[ + :, + :, + region.latent_row_init : region.latent_row_end, + region.latent_col_init : region.latent_col_end, + ] + eps_noise = ( + torch.randn( + region_noise.shape, generator=region.get_region_generator(self.device), device=self.device + ) + * region.noise_eps + ) + init_noise[ + :, + :, + region.latent_row_init : region.latent_row_end, + region.latent_col_init : region.latent_col_end, + ] += eps_noise + + # scale the initial noise by the standard deviation required by the scheduler + latents = init_noise * self.scheduler.init_noise_sigma + + # Get unconditional embeddings for classifier free guidance in text2image regions + for region in text2image_regions: + max_length = region.tokenized_prompt.input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + region.encoded_prompt = torch.cat([uncond_embeddings, region.encoded_prompt]) + + # Prepare image latents + for region in image2image_regions: + region.encode_reference_image(self.vae, device=self.device, generator=generator) + + # Prepare mask of weights for each region + mask_builder = MaskWeightsBuilder(latent_space_dim=self.unet.config.in_channels, nbatch=batch_size) + mask_weights = [mask_builder.compute_mask_weights(region).to(self.device) for region in text2image_regions] + + # Diffusion timesteps + for i, t in tqdm(enumerate(self.scheduler.timesteps)): + # Diffuse each region + noise_preds_regions = [] + + # text2image regions + for region in text2image_regions: + region_latents = latents[ + :, + :, + region.latent_row_init : region.latent_row_end, + region.latent_col_init : region.latent_col_end, + ] + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([region_latents] * 2) + # scale model input following scheduler rules + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=region.encoded_prompt)["sample"] + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_region = noise_pred_uncond + region.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_preds_regions.append(noise_pred_region) + + # Merge noise predictions for all tiles + noise_pred = torch.zeros(latents.shape, device=self.device) + contributors = torch.zeros(latents.shape, device=self.device) + # Add each tile contribution to overall latents + for region, noise_pred_region, mask_weights_region in zip( + text2image_regions, noise_preds_regions, mask_weights + ): + noise_pred[ + :, + :, + region.latent_row_init : region.latent_row_end, + region.latent_col_init : region.latent_col_end, + ] += ( + noise_pred_region * mask_weights_region + ) + contributors[ + :, + :, + region.latent_row_init : region.latent_row_end, + region.latent_col_init : region.latent_col_end, + ] += mask_weights_region + # Average overlapping areas with more than 1 contributor + noise_pred /= contributors + noise_pred = torch.nan_to_num( + noise_pred + ) # Replace NaNs by zeros: NaN can appear if a position is not covered by any DiffusionRegion + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + # Image2Image regions: override latents generated by the scheduler + for region in image2image_regions: + influence_step = self.get_latest_timestep_img2img(num_inference_steps, region.strength) + # Only override in the timesteps before the last influence step of the image (given by its strength) + if t > influence_step: + timestep = t.repeat(batch_size) + region_init_noise = init_noise[ + :, + :, + region.latent_row_init : region.latent_row_end, + region.latent_col_init : region.latent_col_end, + ] + region_latents = self.scheduler.add_noise(region.reference_latents, region_init_noise, timestep) + latents[ + :, + :, + region.latent_row_init : region.latent_row_end, + region.latent_col_init : region.latent_col_end, + ] = region_latents + + if decode_steps: + steps_images.append(self.decode_latents(latents, cpu_vae)) + + # scale and decode the image latents with vae + image = self.decode_latents(latents, cpu_vae) + + output = {"images": image} + if decode_steps: + output = {**output, "steps_images": steps_images} + return output