Skip to content

Commit

Permalink
Add Latent Consistency Models Pipeline (#5448)
Browse files Browse the repository at this point in the history
* initial commit for LatentConsistencyModelPipeline and LCMScheduler based on the community pipeline

* Add callback and freeu support.

* apply suggestions from review

* Clean up LCMScheduler

* Remove timeindex argument to LCMScheduler.step.

* Add support for clipping or thresholding the predicted original sample.

* Remove unused methods and arguments in LCMScheduler.

* Improve comment about (lack of) negative prompt support.

* Change input guidance_scale to match the StableDiffusionPipeline (Imagen) CFG formulation.

* Move lcm_origin_steps from pipeline __call__ to LCMScheduler.__init__/config (as origin_steps).

* Fix typo when clipping/thresholding in LCMScheduler.

* Add some initial LCMScheduler tests.

* add type annotations from review

* Fix type annotation bug.

* Override test_add_noise_device in LCMSchedulerTest since hardcoded timesteps doesn't work under default settings.

* Add generator argument pipeline prepare_latents call.

* Cast LCMScheduler.timesteps to long in set_timesteps.

* Add onestep and multistep full loop scheduler tests.

* Set default height/width to None and don't hardcode guidance scale embedding dim.

* Add initial LatentConsistencyPipeline fast and slow tests.

* Add initial documentation for LatentConsistencyModelPipeline and LCMScheduler.

* Make remaining failing fast tests pass.

* make style

* Make original_inference_steps configurable from pipeline __call__ again.

* make style

* Remove guidance_rescale arg from pipeline __call__ since LCM currently doesn't support CFG.

* Make LCMScheduler defaults match config of LCM_Dreamshaper_v7 checkpoint.

* Fix LatentConsistencyPipeline slow tests and add dummy expected slices.

* Add checks for original_steps in LCMScheduler.set_timesteps.

* make fix-copies

* Improve LatentConsistencyModelPipeline docs.

* Apply suggestions from code review

Co-authored-by: Aryan V S <avs050602@gmail.com>

* Apply suggestions from code review

Co-authored-by: Aryan V S <avs050602@gmail.com>

* Apply suggestions from code review

Co-authored-by: Aryan V S <avs050602@gmail.com>

* Update src/diffusers/schedulers/scheduling_lcm.py

* Apply suggestions from code review

Co-authored-by: Aryan V S <avs050602@gmail.com>

* finish

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Aryan V S <avs050602@gmail.com>
  • Loading branch information
4 people authored Oct 24, 2023
1 parent 7c3a75a commit 958e17d
Show file tree
Hide file tree
Showing 14 changed files with 1,758 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@
title: Kandinsky
- local: api/pipelines/kandinsky_v22
title: Kandinsky 2.2
- local: api/pipelines/latent_consistency_models
title: Latent Consistency Models
- local: api/pipelines/latent_diffusion
title: Latent Diffusion
- local: api/pipelines/panorama
Expand Down Expand Up @@ -368,6 +370,8 @@
title: KDPM2AncestralDiscreteScheduler
- local: api/schedulers/dpm_discrete
title: KDPM2DiscreteScheduler
- local: api/schedulers/lcm
title: LCMScheduler
- local: api/schedulers/lms_discrete
title: LMSDiscreteScheduler
- local: api/schedulers/pndm
Expand Down
44 changes: 44 additions & 0 deletions docs/source/en/api/pipelines/latent_consistency_models.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Latent Consistency Models

Latent Consistency Models (LCMs) were proposed in [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) by Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, and Hang Zhao.

The abstract of the [paper](https://arxiv.org/pdf/2310.04378.pdf) is as follows:

*Latent Diffusion models (LDMs) have achieved remarkable results in synthesizing high-resolution images. However, the iterative sampling process is computationally intensive and leads to slow generation. Inspired by Consistency Models (song et al.), we propose Latent Consistency Models (LCMs), enabling swift inference with minimal steps on any pre-trained LDMs, including Stable Diffusion (rombach et al). Viewing the guided reverse diffusion process as solving an augmented probability flow ODE (PF-ODE), LCMs are designed to directly predict the solution of such ODE in latent space, mitigating the need for numerous iterations and allowing rapid, high-fidelity sampling. Efficiently distilled from pre-trained classifier-free guided diffusion models, a high-quality 768 x 768 2~4-step LCM takes only 32 A100 GPU hours for training. Furthermore, we introduce Latent Consistency Fine-tuning (LCF), a novel method that is tailored for fine-tuning LCMs on customized image datasets. Evaluation on the LAION-5B-Aesthetics dataset demonstrates that LCMs achieve state-of-the-art text-to-image generation performance with few-step inference.*

A demo for the [SimianLuo/LCM_Dreamshaper_v7](https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7) checkpoint can be found [here](https://huggingface.co/spaces/SimianLuo/Latent_Consistency_Model).

This pipeline was contributed by [luosiallen](https://luosiallen.github.io/) and [dg845](https://github.com/dg845).

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float32)

# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
pipe.to(torch_device="cuda", torch_dtype=torch.float32)

prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"

# Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
num_inference_steps = 4

images = pipe(prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=8.0).images
```

## LatentConsistencyModelPipeline

[[autodoc]] LatentConsistencyModelPipeline
- all
- __call__
- enable_freeu
- disable_freeu
- enable_vae_slicing
- disable_vae_slicing
- enable_vae_tiling
- disable_vae_tiling

## StableDiffusionPipelineOutput

[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
9 changes: 9 additions & 0 deletions docs/source/en/api/schedulers/lcm.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Latent Consistency Model Multistep Scheduler

## Overview

Multistep and onestep scheduler (Algorithm 3) introduced alongside latent consistency models in the paper [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) by Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, and Hang Zhao.
This scheduler should be able to generate good samples from [`LatentConsistencyModelPipeline`] in 1-8 steps.

## LCMScheduler
[[autodoc]] LCMScheduler
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@
"KarrasVeScheduler",
"KDPM2AncestralDiscreteScheduler",
"KDPM2DiscreteScheduler",
"LCMScheduler",
"PNDMScheduler",
"RePaintScheduler",
"SchedulerMixin",
Expand Down Expand Up @@ -226,6 +227,7 @@
"KandinskyV22Pipeline",
"KandinskyV22PriorEmb2EmbPipeline",
"KandinskyV22PriorPipeline",
"LatentConsistencyModelPipeline",
"LDMTextToImagePipeline",
"MusicLDMPipeline",
"PaintByExamplePipeline",
Expand Down Expand Up @@ -499,6 +501,7 @@
KarrasVeScheduler,
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
LCMScheduler,
PNDMScheduler,
RePaintScheduler,
SchedulerMixin,
Expand Down Expand Up @@ -564,6 +567,7 @@
KandinskyV22Pipeline,
KandinskyV22PriorEmb2EmbPipeline,
KandinskyV22PriorPipeline,
LatentConsistencyModelPipeline,
LDMTextToImagePipeline,
MusicLDMPipeline,
PaintByExamplePipeline,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
"KandinskyV22PriorEmb2EmbPipeline",
"KandinskyV22PriorPipeline",
]
_import_structure["latent_consistency_models"] = ["LatentConsistencyModelPipeline"]
_import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
_import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
Expand Down Expand Up @@ -331,6 +332,7 @@
KandinskyV22PriorEmb2EmbPipeline,
KandinskyV22PriorPipeline,
)
from .latent_consistency_models import LatentConsistencyModelPipeline
from .latent_diffusion import LDMTextToImagePipeline
from .musicldm import MusicLDMPipeline
from .paint_by_example import PaintByExamplePipeline
Expand Down
22 changes: 22 additions & 0 deletions src/diffusers/pipelines/latent_consistency_models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import TYPE_CHECKING

from ...utils import (
_LazyModule,
)


_import_structure = {"pipeline_latent_consistency_models": ["LatentConsistencyModelPipeline"]}


if TYPE_CHECKING:
from .pipeline_latent_consistency_models import LatentConsistencyModelPipeline

else:
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
Loading

0 comments on commit 958e17d

Please sign in to comment.