This is the official PyTorch implementation of T-Stitch: Accelerating Sampling in Pre-trained Diffusion Models with Trajectory Stitching
Zizheng Pan¹, Bohan Zhuang¹, De-An Huang², Weili Nie², Zhiding Yu², Chaowei Xiao²,³, Jianfei Cai¹, Anima Anandkumar⁴
¹Monash University, ²NVIDIA, ³University of Wisconsin–Madison, ⁴Caltech
[Paper] [Project Page]
We introduce sampling Trajectory Stitching (T-Stitch), a simple yet efficient technique to improve sampling efficiency with little or no loss in generation quality. Instead of solely using a large DPM for the entire sampling trajectory, T-Stitch first leverages a smaller DPM in the initial steps as a cheap drop-in replacement for the larger DPM, then switches to the larger DPM at a later stage, thus achieving flexible speed and quality trade-offs.
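In pseudocode, the sampling loop looks roughly as follows. This is a minimal sketch of the idea only: eps_small, eps_large, and scheduler are hypothetical stand-ins for the two denoisers and any standard sampler update, not the actual API of this repo.

def t_stitch_sample(x_T, timesteps, eps_small, eps_large, scheduler, ratio):
    # Take the first `ratio` fraction of denoising steps with the small model,
    # then switch to the large model for the remaining steps.
    switch = int(len(timesteps) * ratio)
    x = x_T
    for i, t in enumerate(timesteps):  # timesteps run from high noise to low noise
        eps = eps_small(x, t) if i < switch else eps_large(x, t)
        x = scheduler.step(eps, t, x)  # any standard sampler update (DDIM, DPM-Solver, ...)
    return x

With ratio = 0 this reduces to sampling entirely with the large model, and with ratio = 1 entirely with the small one; intermediate values trade quality for speed.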
An example of stitching in more DiT-S steps to accelerate sampling for DiT-XL, where the time cost is measured in seconds (s) for generating 8 images on one RTX 3090.
By directly adopting a small SD model from the model zoo, T-Stitch naturally interpolates the speed, style, and image content between it and a large stylized SD model, which can also improve prompt alignment, e.g., for "New York City" and "tropical beach" in the above examples.
T-Stitch is fully complementary to existing techniques that reduce the number of sampling steps, e.g., directly using fewer steps, advanced samplers, and step distillation.
For basic usage with diffusers, you can create an environment following our provided requirements.txt. Create a conda environment and activate it:
conda create -n tstitch python=3.9 -y
conda activate tstitch
pip install -r requirements.txt
Then launch the Gradio demo:
python gradio_demo.py
For DiT experiments, please refer to the folder dit for detailed usage.
For LDM experiments, please refer to the folder ldm for detailed usage.
Using T-Stitch with Stable Diffusion models is easy. At the root of this repo, run:
import os

import torch

from tstitch_sd_utils import get_tstitch_pipepline

# A large stylized SD model, and a small SD model for the early sampling steps
large_sd = "Envvi/Inkpunk-Diffusion"
small_sd = "nota-ai/bk-sdm-tiny"
pipe_sd = get_tstitch_pipepline(large_sd, small_sd)

prompt = 'a squirrel in the park, nvinkpunk style'
# Fix the initial latent so that all ratios start from the same noise
latent = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16)

save_dir = 'figures/inkpunk'
os.makedirs(save_dir, exist_ok=True)

# Sweep the fraction of early steps taken by the small model, from 0.0 to 1.0
ratios = [round(item, 1) for item in torch.arange(0, 1.1, 0.1).tolist()]
for ratio in ratios:
    image = pipe_sd(prompt, unet_s_ratio=ratio, latents=latent, height=512, width=512).images[0]
    image.save(f"{save_dir}/sample-ratio-{ratio}.png")
The above script generates images while gradually increasing the fraction of early sampling steps taken by the small SD model. Feel free to try other stylized SD models and other prompts. Also note that both models must operate on latents of the same shape.
T-Stitch provides a smooth speed and quality trade-off between a compressed SSD-1B and the original SDXL. Try the following command for this demo:
python sdxl_demo.py
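Under the hood, the demo hands the denoising trajectory from SSD-1B over to SDXL partway through. For intuition, a similar split can be sketched with stock diffusers pipelines using the denoising_end/denoising_start arguments. This is a rough sketch of the idea, not the demo's actual implementation; the prompt and the 0.5 split are illustrative.

import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline

# The small model handles the early, high-noise steps
small = StableDiffusionXLPipeline.from_pretrained(
    "segmind/SSD-1B", torch_dtype=torch.float16).to("cuda")
# The large model finishes the trajectory from the handed-over latents
large = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")

prompt = "a majestic lion jumping from a big stone at night"
ratio = 0.5  # fraction of the trajectory taken by the small model

# Run the first fraction of the denoising steps with SSD-1B and keep the latents
latents = small(prompt, num_inference_steps=30, denoising_end=ratio,
                output_type="latent").images
# Continue the remaining steps with SDXL
image = large(prompt, num_inference_steps=30, denoising_start=ratio,
              image=latents).images[0]
image.save("tstitch_sdxl.png")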
T-Stitch is compatible with ControlNet. For example (see also the reference sketch after this list):
To use canny edges with SDXL, run python sdxl_canny.py
To use depth images with SDXL, run python sdxl_depth.py
To use poses with SDXL, run python sdxl_pose.py
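For reference, the following shows how canny-edge conditioning is typically set up with stock diffusers components for SDXL; our reading is that sdxl_canny.py wires this kind of conditioning into the stitched pipeline, so treat this as an illustrative sketch rather than the script's exact code (the prompt and the local reference.png are placeholders):

import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet,
    torch_dtype=torch.float16).to("cuda")

# Extract canny edges from a reference image to use as the control signal
image = np.array(Image.open("reference.png").convert("RGB").resize((1024, 1024)))
edges = cv2.Canny(image, 100, 200)
canny = Image.fromarray(np.stack([edges] * 3, axis=-1))  # 3-channel edge map

result = pipe("a futuristic research complex, concept art",
              image=canny, num_inference_steps=30).images[0]
result.save("canny_sample.png")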
T-Stitch is compatible with step-distilled models such as LCM-SDXL for further speedup. For example, by adopting a small LCM-distilled SSD-1B, T-Stitch still obtains impressive speed and quality trade-offs. We provide a script to demonstrate this compatibility:
python sdxl_lcm_lora.py
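The standard diffusers recipe for enabling LCM-LoRA on an SDXL-family pipeline is sketched below; our understanding is that sdxl_lcm_lora.py applies this kind of recipe to the models before stitching, so treat the sketch as illustrative (the prompt and step count are placeholders):

import torch
from diffusers import LCMScheduler, StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
# Swap in the LCM scheduler and load the LCM-LoRA weights
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")

# LCM-distilled models sample well in very few steps with low guidance
image = pipe("a photo of an astronaut riding a horse",
             num_inference_steps=8, guidance_scale=1.0).images[0]
image.save("lcm_sample.png")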
Thanks to the open-source codebases DiT, ADM, Diffusers, and LDM, on which our codebase is built.
T-Stitch is licensed under CC-BY-NC. See LICENSE.txt for details. Portions of the project are available under separate license terms: LDM is licensed under the MIT License.