Skip to content

Commit

Permalink
update readme (and run formatter on latent_consistency_img2img.py)
Browse files Browse the repository at this point in the history
  • Loading branch information
nagolinc committed Oct 23, 2023
1 parent 0ed52c4 commit 072e682
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 12 deletions.
33 changes: 33 additions & 0 deletions examples/community/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback fr
sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |


To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
Expand Down Expand Up @@ -2185,3 +2186,35 @@ images = pipe(prompt=prompt, num_inference_steps=num_inference_steps, guidance_s
For any questions or feedback, feel free to reach out to [Simian Luo](https://github.com/luosiallen).

You can also try this pipeline directly in the [🚀 official spaces](https://huggingface.co/spaces/SimianLuo/Latent_Consistency_Model).



### Latent Consistency Img2img Pipeline

This pipeline extends the Latent Consistency Pipeline to allow it to take an input image.

```py
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", custom_pipeline="latent_consistency_img2img")

# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
pipe.to(torch_device="cuda", torch_dtype=torch.float32)
```

- 2. Run inference with as little as 4 steps:

```py
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"


input_image=Image.open("myimg.png")

strength = 0.5 #strength =0 (no change) strength=1 (completely overwrite image)

# Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
num_inference_steps = 4

images = pipe(prompt=prompt, image=input_image, strength=strength, num_inference_steps=num_inference_steps, guidance_scale=8.0, lcm_origin_steps=50, output_type="pil").images
```
21 changes: 9 additions & 12 deletions examples/community/latent_consistency_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,15 @@ def run_safety_checker(self, image, device, dtype):

def prepare_latents(self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, latents=None, generator=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)

if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)

image = image.to(device=device, dtype=dtype)

#batch_size = batch_size * num_images_per_prompt
# batch_size = batch_size * num_images_per_prompt

if image.shape[1] == 4:
init_latents = image
Expand Down Expand Up @@ -207,7 +207,7 @@ def prepare_latents(self, image, timestep, batch_size, num_channels_latents, hei
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
" your script to pass as many initial images as text prompts to suppress this warning."
)
#deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
# deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
additional_image_per_prompt = batch_size // init_latents.shape[0]
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
Expand All @@ -225,10 +225,7 @@ def prepare_latents(self, image, timestep, batch_size, num_channels_latents, hei
latents = init_latents

return latents





if latents is None:
latents = torch.randn(shape, dtype=dtype).to(device)
else:
Expand Down Expand Up @@ -259,7 +256,7 @@ def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
emb = torch.nn.functional.pad(emb, (0, 1))
assert emb.shape == (w.shape[0], embedding_dim)
return emb

def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
Expand Down Expand Up @@ -314,9 +311,9 @@ def __call__(
image = self.image_processor.preprocess(image)

# 4. Prepare timesteps
self.scheduler.set_timesteps(strength,num_inference_steps, lcm_origin_steps)
#timesteps = self.scheduler.timesteps
#timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device)
self.scheduler.set_timesteps(strength, num_inference_steps, lcm_origin_steps)
# timesteps = self.scheduler.timesteps
# timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device)
timesteps = self.scheduler.timesteps
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

Expand Down Expand Up @@ -665,7 +662,7 @@ def set_timesteps(self, stength, num_inference_steps: int, lcm_origin_steps: int

# LCM Timesteps Setting: # Linear Spacing
c = self.config.num_train_timesteps // lcm_origin_steps
lcm_origin_timesteps = np.asarray(list(range(1, int(lcm_origin_steps*stength) + 1))) * c - 1 # LCM Training Steps Schedule
lcm_origin_timesteps = np.asarray(list(range(1, int(lcm_origin_steps * stength) + 1))) * c - 1 # LCM Training Steps Schedule
skipping_step = len(lcm_origin_timesteps) // num_inference_steps
timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] # LCM Inference Steps Schedule

Expand Down

0 comments on commit 072e682

Please sign in to comment.