From ae475228db8923b0aeb453eb868f2de53d5ea34f Mon Sep 17 00:00:00 2001
From: xiaomile <15622388695@163.com>
Date: Tue, 12 Dec 2023 14:32:49 +0800
Subject: [PATCH 1/9] Adding support for Consistency Models
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Clean up the old commit message and open a new PR for https://github.com/open-mmlab/mmagic/pull/2045.
---
configs/consistency_models/README.md | 87 ++
configs/consistency_models/README_zh-CN.md | 88 ++
...odels_8xb256-imagenet1k-multistep-64x64.py | 46 +
..._models_8xb256-imagenet1k-onestep-64x64.py | 46 +
...ls_8xb32-LSUN-bedroom-multistep-256x256.py | 47 +
...dels_8xb32-LSUN-bedroom-onestep-256x256.py | 46 +
...models_8xb32-LSUN-cat-multistep-256x256.py | 47 +
...y_models_8xb32-LSUN-cat-onestep-256x256.py | 46 +
configs/consistency_models/metafile.yml | 52 ++
mmagic/apis/mmagic_inferencer.py | 1 +
mmagic/models/editors/__init__.py | 5 +-
.../editors/consistency_models/__init__.py | 5 +
.../consistency_models/consistencymodel.py | 301 +++++++
.../consistencymodel_modules.py | 761 ++++++++++++++++
.../consistencymodel_utils.py | 817 ++++++++++++++++++
.../test_consistency_models.py | 136 +++
16 files changed, 2530 insertions(+), 1 deletion(-)
create mode 100644 configs/consistency_models/README.md
create mode 100644 configs/consistency_models/README_zh-CN.md
create mode 100644 configs/consistency_models/consistency_models_8xb256-imagenet1k-multistep-64x64.py
create mode 100644 configs/consistency_models/consistency_models_8xb256-imagenet1k-onestep-64x64.py
create mode 100644 configs/consistency_models/consistency_models_8xb32-LSUN-bedroom-multistep-256x256.py
create mode 100644 configs/consistency_models/consistency_models_8xb32-LSUN-bedroom-onestep-256x256.py
create mode 100644 configs/consistency_models/consistency_models_8xb32-LSUN-cat-multistep-256x256.py
create mode 100644 configs/consistency_models/consistency_models_8xb32-LSUN-cat-onestep-256x256.py
create mode 100644 configs/consistency_models/metafile.yml
create mode 100644 mmagic/models/editors/consistency_models/__init__.py
create mode 100644 mmagic/models/editors/consistency_models/consistencymodel.py
create mode 100644 mmagic/models/editors/consistency_models/consistencymodel_modules.py
create mode 100644 mmagic/models/editors/consistency_models/consistencymodel_utils.py
create mode 100644 tests/test_models/test_editors/test_consistency_models/test_consistency_models.py
diff --git a/configs/consistency_models/README.md b/configs/consistency_models/README.md
new file mode 100644
index 0000000000..4658e699e5
--- /dev/null
+++ b/configs/consistency_models/README.md
@@ -0,0 +1,87 @@
+# Consistency Models (ICML'2023)
+
+> [Consistency Models](https://arxiv.org/abs/2303.01469)
+
+> **Task**: conditional
+
+
+
+## Abstract
+
+
+
+Diffusion models have significantly advanced the fields of image, audio, and video generation, but they depend on an iterative sampling process that causes slow generation. To overcome this limitation, we propose consistency models, a new family of models that generate high quality samples by directly mapping noise to data. They support fast one-step generation by design, while still allowing multistep sampling to trade compute for sample quality. They also support zero-shot data editing, such as image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks. Consistency models can be trained either by distilling pre-trained diffusion models, or as standalone generative models altogether. Through extensive experiments, we demonstrate that they outperform existing distillation techniques for diffusion models in one- and few-step sampling, achieving the new state-of-the-art FID of 3.55 on CIFAR-10 and 6.20 on ImageNet 64x64 for one-step generation. When trained in isolation, consistency models become a new family of generative models that can outperform existing one-step, non-adversarial generative models on standard benchmarks such as CIFAR-10, ImageNet 64x64 and LSUN 256x256.
+
+
+
+
+
+## Pre-trained models
+
+| Model | Dataset | Conditional | Download |
+| :-------------------------------------------------------------------------------------------: | :--------: | :---------: | :------: |
+| [onestep on ImageNet-64](./consistency_models_8xb256-imagenet1k-onestep-64x64.py) | imagenet1k | yes | - |
+| [multistep on ImageNet-64](./consistency_models_8xb256-imagenet1k-multistep-64x64.py) | imagenet1k | yes | - |
+| [onestep on LSUN Bedroom-256](./consistency_models_8xb32-LSUN-bedroom-onestep-256x256.py) | LSUN | no | - |
+| [multistep on LSUN Bedroom-256](./consistency_models_8xb32-LSUN-bedroom-multistep-256x256.py) | LSUN | no | - |
+| [onestep on LSUN Cat-256](./consistency_models_8xb32-LSUN-cat-onestep-256x256.py)             | LSUN       | no          | -        |
+| [multistep on LSUN Cat-256](./consistency_models_8xb32-LSUN-cat-multistep-256x256.py) | LSUN | no | - |
+
+You can also download the checkpoints of the main models in the paper to your local machine and pass the local path to `model_path` before inference.
+Here are the download links for each model checkpoint:
+
+- EDM on ImageNet-64: [edm_imagenet64_ema.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/edm_imagenet64_ema.pt)
+- CD on ImageNet-64 with l2 metric: [cd_imagenet64_l2.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_imagenet64_l2.pt)
+- CD on ImageNet-64 with LPIPS metric: [cd_imagenet64_lpips.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_imagenet64_lpips.pt)
+- CT on ImageNet-64: [ct_imagenet64.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/ct_imagenet64.pt)
+- EDM on LSUN Bedroom-256: [edm_bedroom256_ema.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/edm_bedroom256_ema.pt)
+- CD on LSUN Bedroom-256 with l2 metric: [cd_bedroom256_l2.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_bedroom256_l2.pt)
+- CD on LSUN Bedroom-256 with LPIPS metric: [cd_bedroom256_lpips.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_bedroom256_lpips.pt)
+- CT on LSUN Bedroom-256: [ct_bedroom256.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/ct_bedroom256.pt)
+- EDM on LSUN Cat-256: [edm_cat256_ema.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/edm_cat256_ema.pt)
+- CD on LSUN Cat-256 with l2 metric: [cd_cat256_l2.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_cat256_l2.pt)
+- CD on LSUN Cat-256 with LPIPS metric: [cd_cat256_lpips.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_cat256_lpips.pt)
+- CT on LSUN Cat-256: [ct_cat256.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/ct_cat256.pt)
+
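+If you use a local checkpoint, a minimal sketch (the config and checkpoint paths below are placeholders) is to override `model_path` in the config before building the model or running inference:
+
+```python
+from mmengine import Config
+
+cfg = Config.fromfile(
+    'configs/consistency_models/'
+    'consistency_models_8xb256-imagenet1k-onestep-64x64.py')
+# Point the model at the locally downloaded checkpoint instead of the URL.
+cfg.model.model_path = '/path/to/cd_imagenet64_l2.pt'
+```
+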
+## Quick Start
+
+**Infer**
+
+You can use the following commands to infer with the model.
+
+```shell
+# onestep
+python demo/mmagic_inference_demo.py \
+ --model-name consistency_models \
+ --model-config configs/consistency_models/consistency_models_8xb256-imagenet1k-onestep-64x64.py \
+ --result-out-dir demo_consistency_model.jpg
+
+# multistep
+python demo/mmagic_inference_demo.py \
+ --model-name consistency_models \
+ --model-config configs/consistency_models/consistency_models_8xb256-imagenet1k-multistep-64x64.py \
+ --result-out-dir demo_consistency_model.jpg
+
+# conditional sampling with a specified ImageNet class label
+python demo/mmagic_inference_demo.py \
+ --model-name consistency_models \
+ --model-config configs/consistency_models/consistency_models_8xb256-imagenet1k-onestep-64x64.py \
+ --label 145 \
+ --result-out-dir demo_consistency_model.jpg
+```
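+
+Inference can also be driven from Python via `MMagicInferencer` (registered in `mmagic/apis/mmagic_inferencer.py` by this patch). A minimal sketch, assuming the demo-script arguments map directly onto `infer` keyword arguments:
+
+```python
+from mmagic.apis import MMagicInferencer
+
+editor = MMagicInferencer(
+    model_name='consistency_models',
+    model_config='configs/consistency_models/'
+    'consistency_models_8xb256-imagenet1k-onestep-64x64.py')
+# `label` is an ImageNet-1k class index; omit it to sample a random class.
+editor.infer(label=145, result_out_dir='demo_consistency_model.jpg')
+```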
+
+
+
+## Citation
+
+```bibtex
+@article{song2023consistency,
+ title={Consistency Models},
+ author={Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya},
+ journal={arXiv preprint arXiv:2303.01469},
+ year={2023},
+}
+```
diff --git a/configs/consistency_models/README_zh-CN.md b/configs/consistency_models/README_zh-CN.md
new file mode 100644
index 0000000000..ae6405397b
--- /dev/null
+++ b/configs/consistency_models/README_zh-CN.md
@@ -0,0 +1,88 @@
+# Consistency Models (ICML'2023)
+
+> [Consistency Models](https://arxiv.org/abs/2303.01469)
+
+> **任务**: 条件生成
+
+
+
+## 摘要
+
+
+
+扩散模型在图像、音频和视频生成领域取得了显著的进展,但它们依赖于迭代采样过程,导致生成速度较慢。为了克服这个限制,我们提出了一种新的模型家族——一致性模型,通过直接将噪声映射到数据来生成高质量的样本。它们通过设计支持快速的单步生成,同时仍然允许多步采样以在计算和样本质量之间进行权衡。它们还支持零样本数据编辑,如图像修补、上色和超分辨率,而不需要在这些任务上进行显式训练。一致性模型可以通过蒸馏预训练的扩散模型或作为独立的生成模型进行训练。通过大量实验证明,它们在单步和少步采样方面优于现有的扩散模型蒸馏技术,实现了 CIFAR-10 上的新的最先进 FID(Fréchet Inception Distance)为 3.55,ImageNet 64x64 上为 6.20 的结果。当独立训练时,一致性模型成为一种新的生成模型家族,在 CIFAR-10、ImageNet 64x64 和 LSUN 256x256 等标准基准测试上可以优于现有的单步非对抗性生成模型。
+
+
+
+
+
+## 预训练模型
+
+| Model | Dataset | Conditional | Download |
+| :-------------------------------------------------------------------------------------------: | :--------: | :---------: | :------: |
+| [onestep on ImageNet-64](./consistency_models_8xb256-imagenet1k-onestep-64x64.py) | imagenet1k | yes | - |
+| [multistep on ImageNet-64](./consistency_models_8xb256-imagenet1k-multistep-64x64.py) | imagenet1k | yes | - |
+| [onestep on LSUN Bedroom-256](./consistency_models_8xb32-LSUN-bedroom-onestep-256x256.py) | LSUN | no | - |
+| [multistep on LSUN Bedroom-256](./consistency_models_8xb32-LSUN-bedroom-multistep-256x256.py) | LSUN | no | - |
+| [onestep on LSUN Cat-256](./consistency_models_8xb32-LSUN-cat-onestep-256x256.py)             | LSUN       | no          | -        |
+| [multistep on LSUN Cat-256](./consistency_models_8xb32-LSUN-cat-multistep-256x256.py) | LSUN | no | - |
+
+你也可以在进行推理前先把论文中主要模型的权重下载到本地机器上,并将权重路径传给 `model_path`。
+以下是每个模型权重的下载链接:
+
+- EDM on ImageNet-64: [edm_imagenet64_ema.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/edm_imagenet64_ema.pt)
+- CD on ImageNet-64 with l2 metric: [cd_imagenet64_l2.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_imagenet64_l2.pt)
+- CD on ImageNet-64 with LPIPS metric: [cd_imagenet64_lpips.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_imagenet64_lpips.pt)
+- CT on ImageNet-64: [ct_imagenet64.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/ct_imagenet64.pt)
+- EDM on LSUN Bedroom-256: [edm_bedroom256_ema.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/edm_bedroom256_ema.pt)
+- CD on LSUN Bedroom-256 with l2 metric: [cd_bedroom256_l2.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_bedroom256_l2.pt)
+- CD on LSUN Bedroom-256 with LPIPS metric: [cd_bedroom256_lpips.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_bedroom256_lpips.pt)
+- CT on LSUN Bedroom-256: [ct_bedroom256.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/ct_bedroom256.pt)
+- EDM on LSUN Cat-256: [edm_cat256_ema.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/edm_cat256_ema.pt)
+- CD on LSUN Cat-256 with l2 metric: [cd_cat256_l2.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_cat256_l2.pt)
+- CD on LSUN Cat-256 with LPIPS metric: [cd_cat256_lpips.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_cat256_lpips.pt)
+- CT on LSUN Cat-256: [ct_cat256.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/ct_cat256.pt)
+
+## 快速开始
+
+**推理**
+
+您可以使用以下命令来使用该模型进行推理:
+
+```shell
+# 一步生成
+python demo/mmagic_inference_demo.py \
+ --model-name consistency_models \
+ --model-config configs/consistency_models/consistency_models_8xb256-imagenet1k-onestep-64x64.py \
+ --result-out-dir demo_consistency_model.jpg
+
+# 多步生成
+python demo/mmagic_inference_demo.py \
+ --model-name consistency_models \
+ --model-config configs/consistency_models/consistency_models_8xb256-imagenet1k-multistep-64x64.py \
+ --result-out-dir demo_consistency_model.jpg
+
+# 条件控制生成
+python demo/mmagic_inference_demo.py \
+ --model-name consistency_models \
+ --model-config configs/consistency_models/consistency_models_8xb256-imagenet1k-onestep-64x64.py \
+ --label 145 \
+ --result-out-dir demo_consistency_model.jpg
+```
+
+
+
+## Citation
+
+```bibtex
+@article{song2023consistency,
+ title={Consistency Models},
+ author={Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya},
+ journal={arXiv preprint arXiv:2303.01469},
+ year={2023},
+}
+```
diff --git a/configs/consistency_models/consistency_models_8xb256-imagenet1k-multistep-64x64.py b/configs/consistency_models/consistency_models_8xb256-imagenet1k-multistep-64x64.py
new file mode 100644
index 0000000000..974dc2f2c0
--- /dev/null
+++ b/configs/consistency_models/consistency_models_8xb256-imagenet1k-multistep-64x64.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+_base_ = ['../_base_/default_runtime.py']
+
+denoiser_config = dict(
+ type='KarrasDenoiser',
+ sigma_data=0.5,
+ sigma_max=80.0,
+ sigma_min=0.002,
+ weight_schedule='uniform',
+)
+
+unet_config = dict(
+ type='ConsistencyUNetModel',
+ in_channels=3,
+ model_channels=192,
+ num_res_blocks=3,
+ dropout=0.0,
+ channel_mult='',
+ use_checkpoint=False,
+ use_fp16=False,
+ num_head_channels=64,
+ num_heads=4,
+ num_heads_upsample=-1,
+ resblock_updown=True,
+ use_new_attention_order=False,
+ use_scale_shift_norm=True)
+
+model = dict(
+ type='ConsistencyModel',
+ unet=unet_config,
+ denoiser=denoiser_config,
+ attention_resolutions='32,16,8',
+ batch_size=4,
+ class_cond=True,
+ generator='determ',
+ image_size=64,
+ learn_sigma=False,
+ model_path='https://download.openxlab.org.cn/models/xiaomile/'
+ 'consistency_models/weight/cd_imagenet64_l2.pt',
+ num_classes=1000,
+ sampler='multistep',
+ seed=42,
+ training_mode='consistency_distillation',
+ ts='0,22,39',
+ data_preprocessor=dict(
+ type='DataPreprocessor', mean=[127.5] * 3, std=[127.5] * 3))
diff --git a/configs/consistency_models/consistency_models_8xb256-imagenet1k-onestep-64x64.py b/configs/consistency_models/consistency_models_8xb256-imagenet1k-onestep-64x64.py
new file mode 100644
index 0000000000..b2450b88bd
--- /dev/null
+++ b/configs/consistency_models/consistency_models_8xb256-imagenet1k-onestep-64x64.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+_base_ = ['../_base_/default_runtime.py']
+
+denoiser_config = dict(
+ type='KarrasDenoiser',
+ sigma_data=0.5,
+ sigma_max=80.0,
+ sigma_min=0.002,
+ weight_schedule='uniform',
+)
+
+unet_config = dict(
+ type='ConsistencyUNetModel',
+ in_channels=3,
+ model_channels=192,
+ num_res_blocks=3,
+ dropout=0.0,
+ channel_mult='',
+ use_checkpoint=False,
+ use_fp16=False,
+ num_head_channels=64,
+ num_heads=4,
+ num_heads_upsample=-1,
+ resblock_updown=True,
+ use_new_attention_order=False,
+ use_scale_shift_norm=True)
+
+model = dict(
+ type='ConsistencyModel',
+ unet=unet_config,
+ denoiser=denoiser_config,
+ attention_resolutions='32,16,8',
+ batch_size=4,
+ class_cond=True,
+ generator='determ',
+ image_size=64,
+ learn_sigma=False,
+ model_path='https://download.openxlab.org.cn/models/xiaomile/'
+ 'consistency_models/weight/cd_imagenet64_l2.pt',
+ num_classes=1000,
+ sampler='onestep',
+ seed=42,
+ training_mode='consistency_distillation',
+ ts='',
+ data_preprocessor=dict(
+ type='DataPreprocessor', mean=[127.5] * 3, std=[127.5] * 3))
diff --git a/configs/consistency_models/consistency_models_8xb32-LSUN-bedroom-multistep-256x256.py b/configs/consistency_models/consistency_models_8xb32-LSUN-bedroom-multistep-256x256.py
new file mode 100644
index 0000000000..f9b554d4fc
--- /dev/null
+++ b/configs/consistency_models/consistency_models_8xb32-LSUN-bedroom-multistep-256x256.py
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+_base_ = ['../_base_/default_runtime.py']
+
+denoiser_config = dict(
+ type='KarrasDenoiser',
+ sigma_data=0.5,
+ sigma_max=80.0,
+ sigma_min=0.002,
+ weight_schedule='uniform',
+)
+
+unet_config = dict(
+ type='ConsistencyUNetModel',
+ in_channels=3,
+ model_channels=256,
+ num_res_blocks=2,
+ dropout=0.0,
+ channel_mult='',
+ use_checkpoint=False,
+ use_fp16=False,
+ num_head_channels=64,
+ num_heads=4,
+ num_heads_upsample=-1,
+ resblock_updown=True,
+ use_new_attention_order=False,
+ use_scale_shift_norm=False)
+
+model = dict(
+ type='ConsistencyModel',
+ unet=unet_config,
+ denoiser=denoiser_config,
+ attention_resolutions='32,16,8',
+ batch_size=4,
+ class_cond=False,
+ generator='determ-indiv',
+ image_size=256,
+ learn_sigma=False,
+ model_path='https://download.openxlab.org.cn/models/xiaomile/'
+ 'consistency_models/weight/ct_bedroom256.pt',
+ num_classes=1000,
+ sampler='multistep',
+ seed=42,
+ training_mode='consistency_distillation',
+ ts='0,67,150',
+ steps=151,
+ data_preprocessor=dict(
+ type='DataPreprocessor', mean=[127.5] * 3, std=[127.5] * 3))
diff --git a/configs/consistency_models/consistency_models_8xb32-LSUN-bedroom-onestep-256x256.py b/configs/consistency_models/consistency_models_8xb32-LSUN-bedroom-onestep-256x256.py
new file mode 100644
index 0000000000..a6829f7b7b
--- /dev/null
+++ b/configs/consistency_models/consistency_models_8xb32-LSUN-bedroom-onestep-256x256.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+_base_ = ['../_base_/default_runtime.py']
+
+denoiser_config = dict(
+ type='KarrasDenoiser',
+ sigma_data=0.5,
+ sigma_max=80.0,
+ sigma_min=0.002,
+ weight_schedule='uniform',
+)
+
+unet_config = dict(
+ type='ConsistencyUNetModel',
+ in_channels=3,
+ model_channels=256,
+ num_res_blocks=2,
+ dropout=0.0,
+ channel_mult='',
+ use_checkpoint=False,
+ use_fp16=False,
+ num_head_channels=64,
+ num_heads=4,
+ num_heads_upsample=-1,
+ resblock_updown=True,
+ use_new_attention_order=False,
+ use_scale_shift_norm=False)
+
+model = dict(
+ type='ConsistencyModel',
+ unet=unet_config,
+ denoiser=denoiser_config,
+ attention_resolutions='32,16,8',
+ batch_size=4,
+ class_cond=False,
+ generator='determ-indiv',
+ image_size=256,
+ learn_sigma=False,
+ model_path='https://download.openxlab.org.cn/models/xiaomile/'
+ 'consistency_models/weight/ct_bedroom256.pt',
+ num_classes=1000,
+ sampler='onestep',
+ seed=42,
+ training_mode='consistency_distillation',
+ ts='',
+ data_preprocessor=dict(
+ type='DataPreprocessor', mean=[127.5] * 3, std=[127.5] * 3))
diff --git a/configs/consistency_models/consistency_models_8xb32-LSUN-cat-multistep-256x256.py b/configs/consistency_models/consistency_models_8xb32-LSUN-cat-multistep-256x256.py
new file mode 100644
index 0000000000..df146cec77
--- /dev/null
+++ b/configs/consistency_models/consistency_models_8xb32-LSUN-cat-multistep-256x256.py
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+_base_ = ['../_base_/default_runtime.py']
+
+denoiser_config = dict(
+ type='KarrasDenoiser',
+ sigma_data=0.5,
+ sigma_max=80.0,
+ sigma_min=0.002,
+ weight_schedule='uniform',
+)
+
+unet_config = dict(
+ type='ConsistencyUNetModel',
+ in_channels=3,
+ model_channels=256,
+ num_res_blocks=2,
+ dropout=0.0,
+ channel_mult='',
+ use_checkpoint=False,
+ use_fp16=False,
+ num_head_channels=64,
+ num_heads=4,
+ num_heads_upsample=-1,
+ resblock_updown=True,
+ use_new_attention_order=False,
+ use_scale_shift_norm=False)
+
+model = dict(
+ type='ConsistencyModel',
+ unet=unet_config,
+ denoiser=denoiser_config,
+ attention_resolutions='32,16,8',
+ batch_size=4,
+ class_cond=False,
+ generator='determ-indiv',
+ image_size=256,
+ learn_sigma=False,
+ model_path='https://download.openxlab.org.cn/models/xiaomile/'
+ 'consistency_models/weight/ct_cat256.pt',
+ num_classes=1000,
+ sampler='multistep',
+ seed=42,
+ training_mode='consistency_distillation',
+ ts='0,62,150',
+ steps=151,
+ data_preprocessor=dict(
+ type='DataPreprocessor', mean=[127.5] * 3, std=[127.5] * 3))
diff --git a/configs/consistency_models/consistency_models_8xb32-LSUN-cat-onestep-256x256.py b/configs/consistency_models/consistency_models_8xb32-LSUN-cat-onestep-256x256.py
new file mode 100644
index 0000000000..08f938b2fc
--- /dev/null
+++ b/configs/consistency_models/consistency_models_8xb32-LSUN-cat-onestep-256x256.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+_base_ = ['../_base_/default_runtime.py']
+
+denoiser_config = dict(
+ type='KarrasDenoiser',
+ sigma_data=0.5,
+ sigma_max=80.0,
+ sigma_min=0.002,
+ weight_schedule='uniform',
+)
+
+unet_config = dict(
+ type='ConsistencyUNetModel',
+ in_channels=3,
+ model_channels=256,
+ num_res_blocks=2,
+ dropout=0.0,
+ channel_mult='',
+ use_checkpoint=False,
+ use_fp16=False,
+ num_head_channels=64,
+ num_heads=4,
+ num_heads_upsample=-1,
+ resblock_updown=True,
+ use_new_attention_order=False,
+ use_scale_shift_norm=False)
+
+model = dict(
+ type='ConsistencyModel',
+ unet=unet_config,
+ denoiser=denoiser_config,
+ attention_resolutions='32,16,8',
+ batch_size=4,
+ class_cond=False,
+ generator='determ-indiv',
+ image_size=256,
+ learn_sigma=False,
+ model_path='https://download.openxlab.org.cn/models/xiaomile/'
+ 'consistency_models/weight/ct_cat256.pt',
+ num_classes=1000,
+ sampler='onestep',
+ seed=42,
+ training_mode='consistency_distillation',
+ ts='',
+ data_preprocessor=dict(
+ type='DataPreprocessor', mean=[127.5] * 3, std=[127.5] * 3))
diff --git a/configs/consistency_models/metafile.yml b/configs/consistency_models/metafile.yml
new file mode 100644
index 0000000000..3a92929bd8
--- /dev/null
+++ b/configs/consistency_models/metafile.yml
@@ -0,0 +1,52 @@
+Collections:
+- Name: Consistency Models
+ Paper:
+ Title: Consistency Models
+ URL: https://arxiv.org/abs/2303.01469
+ README: configs/consistency_models/README.md
+ Task:
+ - conditional
+ Year: 2023
+Models:
+- Config: configs/consistency_models/consistency_models_8xb256-imagenet1k-onestep-64x64.py
+ In Collection: Consistency Models
+ Name: consistency_models_8xb256-imagenet1k-onestep-64x64
+ Results:
+ - Dataset: imagenet1k
+ Metrics: {}
+ Task: conditional
+- Config: configs/consistency_models/consistency_models_8xb256-imagenet1k-multistep-64x64.py
+ In Collection: Consistency Models
+ Name: consistency_models_8xb256-imagenet1k-multistep-64x64
+ Results:
+ - Dataset: imagenet1k
+ Metrics: {}
+ Task: conditional
+- Config: configs/consistency_models/consistency_models_8xb32-LSUN-bedroom-onestep-256x256.py
+ In Collection: Consistency Models
+ Name: consistency_models_8xb32-LSUN-bedroom-onestep-256x256
+ Results:
+ - Dataset: LSUN
+ Metrics: {}
+ Task: conditional
+- Config: configs/consistency_models/consistency_models_8xb32-LSUN-bedroom-multistep-256x256.py
+ In Collection: Consistency Models
+ Name: consistency_models_8xb32-LSUN-bedroom-multistep-256x256
+ Results:
+ - Dataset: LSUN
+ Metrics: {}
+ Task: conditional
+- Config: configs/consistency_models/consistency_models_8xb32-LSUN-cat-onestep-256x256.py
+ In Collection: Consistency Models
+ Name: consistency_models_8xb32-LSUN-cat-onestep-256x256
+ Results:
+ - Dataset: LSUN
+ Metrics: {}
+ Task: conditional
+- Config: configs/consistency_models/consistency_models_8xb32-LSUN-cat-multistep-256x256.py
+ In Collection: Consistency Models
+ Name: consistency_models_8xb32-LSUN-cat-multistep-256x256
+ Results:
+ - Dataset: LSUN
+ Metrics: {}
+ Task: conditional
diff --git a/mmagic/apis/mmagic_inferencer.py b/mmagic/apis/mmagic_inferencer.py
index 5cb9e4cb50..97882a4566 100644
--- a/mmagic/apis/mmagic_inferencer.py
+++ b/mmagic/apis/mmagic_inferencer.py
@@ -50,6 +50,7 @@ class MMagicInferencer:
'biggan',
'sngan_proj',
'sagan',
+ 'consistency_models',
# unconditional models
'dcgan',
diff --git a/mmagic/models/editors/__init__.py b/mmagic/models/editors/__init__.py
index 0d1085a6b7..90307a1af7 100644
--- a/mmagic/models/editors/__init__.py
+++ b/mmagic/models/editors/__init__.py
@@ -6,6 +6,8 @@
from .basicvsr_plusplus_net import BasicVSRPlusPlusNet
from .biggan import BigGAN
from .cain import CAIN, CAINNet
+from .consistency_models import (ConsistencyModel, ConsistencyUNetModel,
+ KarrasDenoiser)
from .controlnet import ControlStableDiffusion
from .cyclegan import CycleGAN
from .dcgan import DCGAN
@@ -98,5 +100,6 @@
'ControlStableDiffusion', 'DreamBooth', 'TextualInversion', 'DeblurGanV2',
'DeblurGanV2Generator', 'DeblurGanV2Discriminator',
'StableDiffusionInpaint', 'ViCo', 'FastComposer', 'AnimateDiff',
- 'UNet3DConditionMotionModel', 'StableDiffusionXL'
+ 'UNet3DConditionMotionModel', 'StableDiffusionXL', 'ConsistencyModel',
+ 'ConsistencyUNetModel', 'KarrasDenoiser'
]
diff --git a/mmagic/models/editors/consistency_models/__init__.py b/mmagic/models/editors/consistency_models/__init__.py
new file mode 100644
index 0000000000..8053918c0a
--- /dev/null
+++ b/mmagic/models/editors/consistency_models/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .consistencymodel import ConsistencyModel
+from .consistencymodel_modules import ConsistencyUNetModel, KarrasDenoiser
+
+__all__ = ['ConsistencyModel', 'ConsistencyUNetModel', 'KarrasDenoiser']
diff --git a/mmagic/models/editors/consistency_models/consistencymodel.py b/mmagic/models/editors/consistency_models/consistencymodel.py
new file mode 100644
index 0000000000..eb31aefa19
--- /dev/null
+++ b/mmagic/models/editors/consistency_models/consistencymodel.py
@@ -0,0 +1,301 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta
+from copy import deepcopy
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from mmengine import Config
+from mmengine.model import BaseModel
+from torch.hub import load_state_dict_from_url
+
+from mmagic.registry import MODELS
+from mmagic.structures import DataSample
+from mmagic.utils import ForwardInputs
+from .consistencymodel_utils import (device, get_generator, get_sample_fn,
+ get_sigmas_karras, karras_sample)
+
+ModelType = Union[Dict, nn.Module]
+
+
+@MODELS.register_module()
+class ConsistencyModel(BaseModel, metaclass=ABCMeta):
+ """Implementation of `ConsistencyModel.
+
+ `_ (ConsistencyModel).
+ """
+
+ def __init__(self,
+ unet: ModelType,
+ denoiser: ModelType,
+ attention_resolutions: str = '32,16,8',
+ batch_size: int = 4,
+ channel_mult: str = '',
+ class_cond: Union[bool, int] = False,
+ generator: str = 'determ-indiv',
+ image_size: int = 256,
+ learn_sigma: bool = False,
+ model_path: Optional[str] = None,
+ num_classes: int = 0,
+ num_samples: int = 0,
+ sampler: str = 'onestep',
+ seed: int = 0,
+ sigma_max: float = 80.0,
+ sigma_min: float = 0.002,
+ training_mode: str = 'consistency_distillation',
+ ts: str = '',
+ clip_denoised: bool = True,
+ s_churn: float = 0.0,
+ s_noise: float = 1.0,
+ s_tmax: float = float('inf'),
+ s_tmin: float = 0.0,
+ steps: int = 40,
+ data_preprocessor: Optional[Union[dict, Config]] = None):
+
+ super().__init__(data_preprocessor=data_preprocessor)
+
+ self.num_classes = num_classes
+ if 'consistency' in training_mode:
+ self.distillation = True
+ else:
+ self.distillation = False
+ self.batch_size = batch_size
+ self.class_cond = class_cond
+ self.image_size = image_size
+ self.sampler = sampler
+ self.sigma_max = sigma_max
+ self.sigma_min = sigma_min
+ self.clip_denoised = clip_denoised
+ self.s_churn = s_churn
+ self.s_noise = s_noise
+ self.s_tmax = s_tmax
+ self.s_tmin = s_tmin
+ self.steps = steps
+ self.model_kwargs = {}
+
+ if channel_mult == '':
+ if image_size == 512:
+ channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
+ elif image_size == 256:
+ channel_mult = (1, 1, 2, 2, 4, 4)
+ elif image_size == 128:
+ channel_mult = (1, 1, 2, 3, 4)
+ elif image_size == 64:
+ channel_mult = (1, 2, 3, 4)
+ else:
+ raise ValueError(f'unsupported image size: {image_size}')
+ else:
+ channel_mult = tuple(
+ int(ch_mult) for ch_mult in channel_mult.split(','))
+
+ attention_ds = []
+ for res in attention_resolutions.split(','):
+ attention_ds.append(image_size // int(res))
+
+ if isinstance(unet, dict):
+ unet['image_size'] = image_size
+ unet['out_channels'] = (3 if not learn_sigma else 6)
+ unet['num_classes'] = (num_classes if class_cond else None)
+ unet['attention_resolutions'] = tuple(attention_ds)
+ unet['channel_mult'] = channel_mult
+ self.model = MODELS.build(unet)
+ else:
+ self.model = unet
+
+ if isinstance(denoiser, dict):
+ denoiser['distillation'] = self.distillation
+ self.diffusion = MODELS.build(denoiser)
+ else:
+ self.diffusion = denoiser
+
+ if model_path:
+ if 'https://' in model_path or 'http://' in model_path:
+ self.model.load_state_dict(
+ load_state_dict_from_url(model_path, map_location='cpu'))
+ else:
+ self.model.load_state_dict(
+ torch.load(model_path, map_location='cpu'))
+
+ self.model.to(device())
+
+ if sampler == 'multistep':
+ assert len(ts) > 0
+ self.ts = tuple(int(x) for x in ts.split(','))
+ else:
+ self.ts = None
+
+ self.all_images = []
+ self.all_labels = []
+ if num_samples <= 0:
+ self.num_samples = batch_size
+ else:
+ self.num_samples = num_samples
+ self.generator = get_generator(generator, self.num_samples, seed)
+
+ def infer(self, class_id: Optional[int] = None):
+ """infer with unet model and diffusion."""
+ self.model.eval()
+ while len(self.all_images) * self.batch_size < self.num_samples:
+ self.model_kwargs = {}
+ if self.class_cond:
+ classes = self.label_fn(class_id)
+ self.model_kwargs['y'] = classes
+ sample = karras_sample(
+ self.diffusion,
+ self.model,
+ (self.batch_size, 3, self.image_size, self.image_size),
+ steps=self.steps,
+ model_kwargs=self.model_kwargs,
+ device=device(),
+ clip_denoised=self.clip_denoised,
+ sampler=self.sampler,
+ sigma_min=self.sigma_min,
+ sigma_max=self.sigma_max,
+ s_churn=self.s_churn,
+ s_tmin=self.s_tmin,
+ s_tmax=self.s_tmax,
+ s_noise=self.s_noise,
+ generator=self.generator,
+ ts=self.ts,
+ )
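+        # karras_sample returns samples in [-1, 1]; map to uint8 in [0, 255].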
+ sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
+ sample = sample.permute(0, 2, 3, 1)
+ sample = sample.contiguous()
+
+ self.all_images.extend([sample.cpu().numpy()])
+ if self.class_cond:
+ self.all_labels.extend([classes.cpu().numpy()])
+
+ arr = np.concatenate(self.all_images, axis=0)
+ arr = arr[:self.num_samples]
+ label_arr = []
+ if self.class_cond:
+ label_arr = np.concatenate(self.all_labels, axis=0)
+ label_arr = label_arr[:self.num_samples]
+
+ return arr, label_arr
+
+ def forward(self,
+ inputs: ForwardInputs,
+ data_samples: Optional[list] = None,
+ mode: Optional[str] = None) -> List[DataSample]:
+ """Sample images with the given inputs. If forward mode is 'ema' or
+ 'orig', the image generated by corresponding generator will be
+ returned. If forward mode is 'ema/orig', images generated by original
+ generator and EMA generator will both be returned in a dict.
+
+ Args:
+ inputs (ForwardInputs): Dict containing the necessary
+ information (e.g. noise, num_batches, mode) to generate image.
+ data_samples (Optional[list]): Data samples collated by
+ :attr:`data_preprocessor`. Defaults to None.
+            mode (Optional[str]): `mode` is not used in
+                :class:`ConsistencyModel`. Defaults to None.
+
+ Returns:
+ List[DataSample]: Generated images or image dict.
+ """
+
+ self.model_kwargs = {}
+ progress = False
+ callback = None
+ sample_kwargs = inputs.get('sample_kwargs', dict())
+        labels = self.label_fn(inputs.get('labels', None))
+        if self.class_cond:
+            assert len(labels) > 0, 'If class_cond is True, ' \
+                'labels must not be empty.'
+ self.model_kwargs['y'] = labels
+ sample_model = inputs.get('sample_model', None)
+ if self.generator is None:
+ self.generator = get_generator('dummy')
+
+ if self.sampler == 'progdist':
+ sigmas = get_sigmas_karras(
+ self.steps + 1,
+ self.sigma_min,
+ self.sigma_max,
+ self.diffusion.rho,
+ device=device())
+ else:
+ sigmas = get_sigmas_karras(
+ self.steps,
+ self.sigma_min,
+ self.sigma_max,
+ self.diffusion.rho,
+ device=device())
+
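+        # Start the sampler from x_T ~ N(0, sigma_max^2 I).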
+ noise = self.generator.randn(
+ *(self.batch_size, 3, self.image_size, self.image_size),
+ device=device()) * self.sigma_max
+
+ sample_fn = get_sample_fn(self.sampler)
+
+ if self.sampler in ['heun', 'dpm']:
+ sampler_args = dict(
+ s_churn=self.s_churn,
+ s_tmin=self.s_tmin,
+ s_tmax=self.s_tmax,
+ s_noise=self.s_noise)
+ elif self.sampler == 'multistep':
+ sampler_args = dict(
+ ts=self.ts,
+ t_min=self.sigma_min,
+ t_max=self.sigma_max,
+ rho=self.diffusion.rho,
+ steps=self.steps)
+ else:
+ sampler_args = {}
+
+ outputs = sample_fn(
+ self.denoiser,
+ noise,
+ sigmas,
+ self.generator,
+ progress=progress,
+ callback=callback,
+ **sampler_args,
+ ).clamp(-1, 1)
+ outputs = self.data_preprocessor.destruct(outputs, data_samples)
+ outputs = self.data_preprocessor._do_conversion(outputs, 'BGR',
+ 'RGB')[0]
+
+ gen_sample = DataSample()
+ if data_samples:
+ gen_sample.update(data_samples)
+ gen_sample.fake_img = outputs
+ gen_sample.noise = noise
+ gen_sample.set_gt_label(labels)
+ gen_sample.sample_kwargs = deepcopy(sample_kwargs)
+ gen_sample.sample_model = sample_model
+ batch_sample_list = gen_sample.split(allow_nonseq_value=True)
+
+ return batch_sample_list
+
+ def denoiser(self, x_t, sigma):
+ """return diffusion's denoiser."""
+ _, denoised = self.diffusion.denoise(self.model, x_t, sigma,
+ **self.model_kwargs)
+ if self.clip_denoised:
+ denoised = denoised.clamp(-1, 1)
+ return denoised
+
+ def label_fn(self, class_id):
+ """return random class_id if class_id is none."""
+ assert self.num_classes > 0, \
+ 'If class_cond is not False,' \
+ 'num_classes should be larger than zero.'
+ if class_id:
+ assert -1 < int(class_id) < self.num_classes, \
+ 'If class_cond has been defined as a class_label_id, ' \
+ 'it should be within the range (0,num_classes).'
+ classes = torch.tensor(
+ [int(class_id) for i in range(self.batch_size)],
+ device=device())
+ else:
+ classes = torch.randint(
+ low=0,
+ high=self.num_classes,
+ size=(self.batch_size, ),
+ device=device())
+ return classes
diff --git a/mmagic/models/editors/consistency_models/consistencymodel_modules.py b/mmagic/models/editors/consistency_models/consistencymodel_modules.py
new file mode 100644
index 0000000000..990535ebcc
--- /dev/null
+++ b/mmagic/models/editors/consistency_models/consistencymodel_modules.py
@@ -0,0 +1,761 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from abc import abstractmethod
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmagic.registry import MODELS
+from .consistencymodel_utils import (append_dims, avg_pool_nd, checkpoint,
+ conv_nd, convert_module_to_f16,
+ convert_module_to_f32, linear,
+ normalization, timestep_embedding,
+ zero_module)
+
+
+@MODELS.register_module()
+class ConsistencyUNetModel(nn.Module):
+ """The full UNet model with attention and timestep embedding.
+
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+    :param num_head_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for
+ potentially increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = torch.float16 if use_fp16 else torch.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ ch = input_ch = int(channel_mult[0] * model_channels)
+ self.input_blocks = nn.ModuleList([
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, ch, 3, padding=1))
+ ])
+ self._feature_size = ch
+ input_block_chans = [ch]
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=int(mult * model_channels),
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = int(mult * model_channels)
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ use_fp16=use_fp16,
+ ))
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ ) if resblock_updown else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch))
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ use_fp16=use_fp16,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=int(model_channels * mult),
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = int(model_channels * mult)
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ use_fp16=use_fp16,
+ ))
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ ) if resblock_updown else Upsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch))
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
+ )
+ if use_fp16:
+ self.convert_to_fp16()
+
+ def convert_to_fp16(self):
+ """Convert the torso of the model to float16."""
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """Convert the torso of the model to float32."""
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps, y=None):
+ """Apply the model to an input batch.
+
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), 'must specify y if and only if the model is class-conditional'
+
+ hs = []
+ emb = self.time_embed(
+ timestep_embedding(timesteps, self.model_channels))
+
+ if self.num_classes is not None:
+ assert y.shape == (x.shape[0], )
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ hs.append(h)
+ h = self.middle_block(h, emb)
+ for module in self.output_blocks:
+ h = torch.cat([h, hs.pop()], dim=1)
+ h = module(h, emb)
+ h = h.type(x.dtype)
+ return self.out(h)
+
+
+class AttentionPool2d(nn.Module):
+ """Adapted from CLIP:
+
+ https://github.com/openai/CLIP/blob/main/clip/model.py.
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(
+ torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ """Forward function."""
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """Any module where forward() takes timestep embeddings as a second
+ argument."""
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """Apply the module to `x` given `emb` timestep embeddings."""
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """A sequential module that passes timestep embeddings to the children that
+ support it as an extra input."""
+
+ def forward(self, x, emb):
+ """Forward function."""
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """An upsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(
+ dims, self.channels, self.out_channels, 3, padding=1)
+
+ def forward(self, x):
+ """Forward function."""
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
+ mode='nearest')
+ else:
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ """A downsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=1)
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ """Forward function."""
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """A residual block that can optionally change the number of channels.
+
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels
+ if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(
+ dims, self.out_channels, self.out_channels, 3, padding=1)),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1)
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels,
+ 1)
+
+ def forward(self, x, emb):
+ """Apply the block to a Tensor, conditioned on a timestep embedding.
+
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(self._forward, (x, emb), self.parameters(),
+ self.use_checkpoint)
+
+ def _forward(self, x, emb):
+ """Forward function."""
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """An attention block that allows spatial positions to attend to each
+ other.
+
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ attention_type='legacy',
+ encoder_channels=None,
+ dims=2,
+ use_fp16=False,
+ channels_last=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.use_fp16 = use_fp16
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f'q,k,v channels {channels} is not divisible ' \
+ f'by num_head_channels {num_head_channels}'
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(dims, channels, channels * 3, 1)
+ self.attention_type = attention_type
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads, self.use_fp16)
+
+ self.use_attention_checkpoint = not (self.use_checkpoint
+ or self.attention_type == 'flash')
+ if encoder_channels is not None:
+ assert attention_type != 'flash'
+ self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
+ self.proj_out = zero_module(conv_nd(dims, channels, channels, 1))
+
+ def forward(self, x, encoder_out=None):
+ """Forward function."""
+ if encoder_out is None:
+ return checkpoint(self._forward, (x, ), self.parameters(),
+ self.use_checkpoint)
+ else:
+ return checkpoint(self._forward, (x, encoder_out),
+ self.parameters(), self.use_checkpoint)
+
+ def _forward(self, x, encoder_out=None):
+ """Forward function."""
+ b, _, *spatial = x.shape
+ qkv = self.qkv(self.norm(x)).view(b, -1, np.prod(spatial))
+ if encoder_out is not None:
+ encoder_out = self.encoder_kv(encoder_out)
+ h = checkpoint(self.attention, (qkv, encoder_out), (),
+ self.use_attention_checkpoint)
+ else:
+ h = checkpoint(self.attention, (qkv, ), (),
+ self.use_attention_checkpoint)
+ h = h.view(b, -1, *spatial)
+ h = self.proj_out(h)
+ return x + h
+
+
+def count_flops_attn(model, _x, y):
+ """A counter for the `thop` package to count the operations in an attention
+ operation.
+
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial**2) * c
+ model.total_ops += torch.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """A module which performs QKV attention.
+
+ Matches legacy QKVAttention + input/output heads shaping
+ """
+
+ def __init__(self, n_heads, use_fp16=True):
+ super().__init__()
+ self.n_heads = n_heads
+ self.use_fp16 = use_fp16
+ from einops import rearrange
+ self.rearrange = rearrange
+
+ def forward(self, qkv):
+ """Apply QKV attention.
+
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ if self.use_fp16:
+ qkv = qkv.half()
+
+ qkv = self.rearrange(
+ qkv, 'b (three h d) s -> b s three h d', three=3, h=self.n_heads)
+ q, k, v = qkv.transpose(1, 3).transpose(3, 4).split(1, dim=2)
+ q = q.reshape(bs * self.n_heads, ch, length)
+ k = k.reshape(bs * self.n_heads, ch, length)
+ v = v.reshape(bs * self.n_heads, ch, length)
+
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = torch.einsum(
+ 'bct,bcs->bts', q * scale,
+ k * scale) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight, dim=-1).type(weight.dtype)
+ a = torch.einsum('bts,bcs->bct', weight, v)
+ a = a.float()
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ """return count flops attention."""
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """A module which performs QKV attention.
+
+ Fallback from Blocksparse if use_fp16=False
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv, encoder_kv=None):
+ """Apply QKV attention.
+
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ if encoder_kv is not None:
+ assert encoder_kv.shape[1] == 2 * ch * self.n_heads
+ ek, ev = encoder_kv.chunk(2, dim=1)
+ k = torch.cat([ek, k], dim=-1)
+ v = torch.cat([ev, v], dim=-1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = torch.einsum(
+ 'bct,bcs->bts',
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, -1),
+ ) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = torch.einsum('bts,bcs->bct', weight,
+ v.reshape(bs * self.n_heads, ch, -1))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ """return count flops attention."""
+ return count_flops_attn(model, _x, y)
+
+
+@MODELS.register_module()
+class KarrasDenoiser:
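+    """EDM-style denoiser preconditioning from Karras et al. (2022), with
+    optional boundary-condition scalings for consistency distillation."""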
+
+ def __init__(
+ self,
+ sigma_data: float = 0.5,
+ sigma_max=80.0,
+ sigma_min=0.002,
+ rho=7.0,
+ weight_schedule='karras',
+ distillation=False,
+ loss_norm='lpips',
+ ):
+ self.sigma_data = sigma_data
+ self.sigma_max = sigma_max
+ self.sigma_min = sigma_min
+ self.weight_schedule = weight_schedule
+ self.distillation = distillation
+ self.loss_norm = loss_norm
+ self.rho = rho
+ self.num_timesteps = 40
+
+ def get_snr(self, sigmas):
+ """return snr."""
+ return sigmas**-2
+
+ def get_sigmas(self, sigmas):
+ """return sigmas."""
+ return sigmas
+
+ def get_scalings(self, sigma):
+ """return scalings."""
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2)**0.5
+ c_in = 1 / (sigma**2 + self.sigma_data**2)**0.5
+ return c_skip, c_out, c_in
+
+ def get_scalings_for_boundary_condition(self, sigma):
+ """return scalings for boundary condition."""
+ c_skip = self.sigma_data**2 / (
+ (sigma - self.sigma_min)**2 + self.sigma_data**2)
+ c_out = ((sigma - self.sigma_min) * self.sigma_data /
+ (sigma**2 + self.sigma_data**2)**0.5)
+ c_in = 1 / (sigma**2 + self.sigma_data**2)**0.5
+ return c_skip, c_out, c_in
+
+ def denoise(self, model, x_t, sigmas, **model_kwargs):
+ """return model's output and denoise."""
+
+ if not self.distillation:
+ c_skip, c_out, c_in = [
+ append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)
+ ]
+ else:
+ c_skip, c_out, c_in = [
+ append_dims(x, x_t.ndim)
+ for x in self.get_scalings_for_boundary_condition(sigmas)
+ ]
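+        # EDM noise conditioning: t = 1000 * ln(sigma) / 4; the 1e-44 term
+        # guards against log(0).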
+ rescaled_t = 1000 * 0.25 * torch.log(sigmas + 1e-44)
+ model_output = model(c_in * x_t, rescaled_t, **model_kwargs)
+ denoised = c_out * model_output + c_skip * x_t
+ return model_output, denoised
diff --git a/mmagic/models/editors/consistency_models/consistencymodel_utils.py b/mmagic/models/editors/consistency_models/consistencymodel_utils.py
new file mode 100644
index 0000000000..73354f8e60
--- /dev/null
+++ b/mmagic/models/editors/consistency_models/consistencymodel_utils.py
@@ -0,0 +1,817 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+
+def device():
+ """return torch.device."""
+ if torch.cuda.is_available():
+ return torch.device('cuda')
+ return torch.device('cpu')
+
+
+def get_weightings(weight_schedule, snrs, sigma_data):
+ """return weightings."""
+ if weight_schedule == 'snr':
+ weightings = snrs
+ elif weight_schedule == 'snr+1':
+ weightings = snrs + 1
+ elif weight_schedule == 'karras':
+ weightings = snrs + 1.0 / sigma_data**2
+ elif weight_schedule == 'truncated-snr':
+ weightings = torch.clamp(snrs, min=1.0)
+ elif weight_schedule == 'uniform':
+ weightings = torch.ones_like(snrs)
+ else:
+ raise NotImplementedError()
+ return weightings
+
+
+class SiLU(nn.Module):
+ """PyTorch 1.7 has SiLU, but we support PyTorch 1.5."""
+
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ """PyTorch 1.7 has GroupNorm32, but we support PyTorch 1.5."""
+
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def karras_sample(
+ diffusion,
+ model,
+ shape,
+ steps,
+ clip_denoised=True,
+ progress=False,
+ callback=None,
+ model_kwargs=None,
+ device=None,
+ sigma_min=0.002,
+ sigma_max=80, # higher for highres?
+ rho=7.0,
+ sampler='heun',
+ s_churn=0.0,
+ s_tmin=0.0,
+ s_tmax=float('inf'),
+ s_noise=1.0,
+ generator=None,
+ ts=None,
+):
+ """karras sample function."""
+ if generator is None:
+ generator = get_generator('dummy')
+
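+    # 'progdist' consumes one extra sigma because sample_progdist drops the
+    # trailing zero from the schedule.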
+ if sampler == 'progdist':
+ sigmas = get_sigmas_karras(
+ steps + 1, sigma_min, sigma_max, rho, device=device)
+ else:
+ sigmas = get_sigmas_karras(
+ steps, sigma_min, sigma_max, rho, device=device)
+
+ x_T = generator.randn(*shape, device=device) * sigma_max
+ sample_fn = get_sample_fn(sampler)
+
+ if sampler in ['heun', 'dpm']:
+ sampler_args = dict(
+ s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise)
+ elif sampler == 'multistep':
+ sampler_args = dict(
+ ts=ts,
+ t_min=sigma_min,
+ t_max=sigma_max,
+ rho=diffusion.rho,
+ steps=steps)
+ else:
+ sampler_args = {}
+
+ def denoiser(x_t, sigma):
+ """denoiser function."""
+ _, denoised = diffusion.denoise(model, x_t, sigma, **model_kwargs)
+ if clip_denoised:
+ denoised = denoised.clamp(-1, 1)
+ return denoised
+
+ x_0 = sample_fn(
+ denoiser,
+ x_T,
+ sigmas,
+ generator,
+ progress=progress,
+ callback=callback,
+ **sampler_args,
+ )
+ return x_0.clamp(-1, 1)
+
+
+def get_sample_fn(sampler):
+ """return sampler function."""
+ return {
+ 'heun': sample_heun,
+ 'dpm': sample_dpm,
+ 'ancestral': sample_euler_ancestral,
+ 'onestep': sample_onestep,
+ 'progdist': sample_progdist,
+ 'euler': sample_euler,
+ 'multistep': stochastic_iterative_sampler,
+ }[sampler]
+
+
+def to_d(x, sigma, denoised):
+ """Converts a denoiser output to a Karras ODE derivative."""
+ return (x - denoised) / append_dims(sigma, x.ndim)
+
+
+def get_ancestral_step(sigma_from, sigma_to):
+ """Calculates the noise level (sigma_down) to step down to and the amount
+ of noise to add (sigma_up) when doing an ancestral sampling step."""
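+    # Chosen so that sigma_down**2 + sigma_up**2 == sigma_to**2: a
+    # deterministic step down plus fresh noise lands exactly at sigma_to.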
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) /
+ sigma_from**2)**0.5
+ sigma_down = (sigma_to**2 - sigma_up**2)**0.5
+ return sigma_down, sigma_up
+
+
+def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device='cpu'):
+ """Constructs the noise schedule of Karras et al.
+
+ (2022).
+ """
+ ramp = torch.linspace(0, 1, n)
+ min_inv_rho = sigma_min**(1 / rho)
+ max_inv_rho = sigma_max**(1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
+ return append_zero(sigmas).to(device)
+
+
+@torch.no_grad()
+def sample_euler_ancestral(model,
+ x,
+ sigmas,
+ generator,
+ progress=False,
+ callback=None):
+ """Ancestral sampling with Euler method steps."""
+ s_in = x.new_ones([x.shape[0]])
+ indices = range(len(sigmas) - 1)
+ if progress:
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ denoised = model(x, sigmas[i] * s_in)
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
+ if callback is not None:
+ callback({
+ 'x': x,
+ 'i': i,
+ 'sigma': sigmas[i],
+ 'sigma_hat': sigmas[i],
+ 'denoised': denoised,
+ })
+ d = to_d(x, sigmas[i], denoised)
+ # Euler method
+ dt = sigma_down - sigmas[i]
+ x = x + d * dt
+ x = x + generator.randn_like(x) * sigma_up
+ return x
+
+
+@torch.no_grad()
+def sample_midpoint_ancestral(model,
+ x,
+ ts,
+ generator,
+ progress=False,
+ callback=None):
+ """Ancestral sampling with midpoint method steps."""
+ s_in = x.new_ones([x.shape[0]])
+ step_size = 1 / len(ts)
+ if progress:
+ from tqdm.auto import tqdm
+
+ ts = tqdm(ts)
+
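+    # Midpoint (RK2) steps: evaluate the derivative at the half step and
+    # advance the full step with it.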
+ for tn in ts:
+ dn = model(x, tn * s_in)
+ dn_2 = model(x + (step_size / 2) * dn, (tn + step_size / 2) * s_in)
+ x = x + step_size * dn_2
+ if callback is not None:
+ callback({'x': x, 'tn': tn, 'dn': dn, 'dn_2': dn_2})
+ return x
+
+
+@torch.no_grad()
+def sample_heun(
+ denoiser,
+ x,
+ sigmas,
+ generator,
+ progress=False,
+ callback=None,
+ s_churn=0.0,
+ s_tmin=0.0,
+ s_tmax=float('inf'),
+ s_noise=1.0,
+):
+ """Implements Algorithm 2 (Heun steps) from Karras et al.
+
+ (2022).
+ """
+ s_in = x.new_ones([x.shape[0]])
+ indices = range(len(sigmas) - 1)
+ if progress:
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
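+        # "Churn" step: when sigma_i lies in [s_tmin, s_tmax], temporarily
+        # raise the noise level to sigma_hat by injecting fresh noise.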
+ gamma = (
+ min(s_churn / (len(sigmas) - 1), 2**0.5 -
+ 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0)
+ eps = generator.randn_like(x) * s_noise
+ sigma_hat = sigmas[i] * (gamma + 1)
+ if gamma > 0:
+ x = x + eps * (sigma_hat**2 - sigmas[i]**2)**0.5
+ denoised = denoiser(x, sigma_hat * s_in)
+ d = to_d(x, sigma_hat, denoised)
+ if callback is not None:
+ callback({
+ 'x': x,
+ 'i': i,
+ 'sigma': sigmas[i],
+ 'sigma_hat': sigma_hat,
+ 'denoised': denoised,
+ })
+ dt = sigmas[i + 1] - sigma_hat
+ if sigmas[i + 1] == 0:
+ # Euler method
+ x = x + d * dt
+ else:
+ # Heun's method
+ x_2 = x + d * dt
+ denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in)
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
+ d_prime = (d + d_2) / 2
+ x = x + d_prime * dt
+ return x
+
+
+@torch.no_grad()
+def sample_euler(
+ denoiser,
+ x,
+ sigmas,
+ generator,
+ progress=False,
+ callback=None,
+):
+ """Implements Algorithm 2 (Heun steps) from Karras et al.
+
+ (2022).
+ """
+ s_in = x.new_ones([x.shape[0]])
+ indices = range(len(sigmas) - 1)
+ if progress:
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ sigma = sigmas[i]
+ denoised = denoiser(x, sigma * s_in)
+ d = to_d(x, sigma, denoised)
+ if callback is not None:
+ callback({
+ 'x': x,
+ 'i': i,
+ 'sigma': sigmas[i],
+ 'denoised': denoised,
+ })
+ dt = sigmas[i + 1] - sigma
+ x = x + d * dt
+ return x
+
+
+@torch.no_grad()
+def sample_dpm(
+ denoiser,
+ x,
+ sigmas,
+ generator,
+ progress=False,
+ callback=None,
+ s_churn=0.0,
+ s_tmin=0.0,
+ s_tmax=float('inf'),
+ s_noise=1.0,
+):
+ """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al.
+
+ (2022).
+ """
+ s_in = x.new_ones([x.shape[0]])
+ indices = range(len(sigmas) - 1)
+ if progress:
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ gamma = (
+ min(s_churn / (len(sigmas) - 1), 2**0.5 -
+ 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0)
+ eps = generator.randn_like(x) * s_noise
+ sigma_hat = sigmas[i] * (gamma + 1)
+ if gamma > 0:
+ x = x + eps * (sigma_hat**2 - sigmas[i]**2)**0.5
+ denoised = denoiser(x, sigma_hat * s_in)
+ d = to_d(x, sigma_hat, denoised)
+ if callback is not None:
+ callback({
+ 'x': x,
+ 'i': i,
+ 'sigma': sigmas[i],
+ 'sigma_hat': sigma_hat,
+ 'denoised': denoised,
+ })
+ sigma_mid = ((sigma_hat**(1 / 3) + sigmas[i + 1]**(1 / 3)) / 2)**3
+ dt_1 = sigma_mid - sigma_hat
+ dt_2 = sigmas[i + 1] - sigma_hat
+ x_2 = x + d * dt_1
+ denoised_2 = denoiser(x_2, sigma_mid * s_in)
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
+ x = x + d_2 * dt_2
+ return x
+
+
+@torch.no_grad()
+def sample_onestep(
+ distiller,
+ x,
+ sigmas,
+ generator=None,
+ progress=False,
+ callback=None,
+):
+ """Single-step generation from a distilled model."""
+ s_in = x.new_ones([x.shape[0]])
+ return distiller(x, sigmas[0] * s_in)
+
+
+@torch.no_grad()
+def stochastic_iterative_sampler(
+ distiller,
+ x,
+ sigmas,
+ generator,
+ ts,
+ progress=False,
+ callback=None,
+ t_min=0.002,
+ t_max=80.0,
+ rho=7.0,
+ steps=40,
+):
+ """sample function stochastic iterative."""
+ t_max_rho = t_max**(1 / rho)
+ t_min_rho = t_min**(1 / rho)
+ s_in = x.new_ones([x.shape[0]])
+
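+    # Walk the index list ts: denoise to x0 at the current sigma, then
+    # re-noise up to the next (lower) sigma on the Karras schedule.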
+ for i in range(len(ts) - 1):
+ t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho))**rho
+ x0 = distiller(x, t * s_in)
+ next_t = (t_max_rho + ts[i + 1] / (steps - 1) *
+ (t_min_rho - t_max_rho))**rho
+ next_t = np.clip(next_t, t_min, t_max)
+ x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)
+
+ return x
+
+
+@torch.no_grad()
+def sample_progdist(
+ denoiser,
+ x,
+ sigmas,
+ generator=None,
+ progress=False,
+ callback=None,
+):
+ """sample function progdist."""
+ s_in = x.new_ones([x.shape[0]])
+ sigmas = sigmas[:-1] # skip the zero sigma
+
+ indices = range(len(sigmas) - 1)
+ if progress:
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ sigma = sigmas[i]
+ denoised = denoiser(x, sigma * s_in)
+ d = to_d(x, sigma, denoised)
+ if callback is not None:
+ callback({
+ 'x': x,
+ 'i': i,
+ 'sigma': sigma,
+ 'denoised': denoised,
+ })
+ dt = sigmas[i + 1] - sigma
+ x = x + d * dt
+
+ return x
+
+
+def conv_nd(dims, *args, **kwargs):
+ """Create a 1D, 2D, or 3D convolution module."""
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f'unsupported dimensions: {dims}')
+
+
+def linear(*args, **kwargs):
+ """Create a linear module."""
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """Create a 1D, 2D, or 3D average pooling module."""
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f'unsupported dimensions: {dims}')
+
+
+def update_ema(target_params, source_params, rate=0.99):
+ """Update target parameters to be closer to those of source parameters
+ using an exponential moving average.
+
+ :param target_params: the target parameter sequence.
+ :param source_params: the source parameter sequence.
+ :param rate: the EMA rate (closer to 1 means slower).
+ """
+ for targ, src in zip(target_params, source_params):
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+def zero_module(module):
+ """Zero out the parameters of a module and return it."""
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """Scale the parameters of a module and return it."""
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """Take the mean over all non-batch dimensions."""
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims
+ dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f'input has {x.ndim} dims '
+ f'but target_dims is {target_dims}, which is less')
+ return x[(..., ) + (None, ) * dims_to_append]
+
+
+def append_zero(x):
+ """add zeors."""
+ return torch.cat([x, x.new_zeros([1])])
+
+
+def normalization(channels):
+ """Make a standard normalization layer.
+
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = torch.exp(-math.log(max_period) *
+ torch.arange(start=0, end=half, dtype=torch.float32) /
+ half).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding,
+ torch.zeros_like(embedding[:, :1])],
+ dim=-1)
+ return embedding
+
+
+def checkpoint(func, inputs, params, flag):
+ """Evaluate a function without caching intermediate activations, allowing
+ for reduced memory at the expense of extra compute in the backward pass.
+
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ """checkpoint function."""
+
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ """run forward function."""
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
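+        # Run the function with autograd disabled; activations are
+        # recomputed in backward() instead of being stored here.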
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ """run backward function."""
+ ctx.input_tensors = [
+ x.detach().requires_grad_(True) for x in ctx.input_tensors
+ ]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def convert_module_to_f16(l1):
+ """Convert primitive modules to float16."""
+ if isinstance(l1, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ l1.weight.data = l1.weight.data.half()
+ if l1.bias is not None:
+ l1.bias.data = l1.bias.data.half()
+
+
+def convert_module_to_f32(l2):
+ """Convert primitive modules to float32, undoing
+ convert_module_to_f16()."""
+ if isinstance(l2, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ l2.weight.data = l2.weight.data.float()
+ if l2.bias is not None:
+ l2.bias.data = l2.bias.data.float()
+
+
+def get_generator(generator, num_samples=0, seed=0):
+ """return generator."""
+ if generator == 'dummy':
+ return DummyGenerator()
+ elif generator == 'determ':
+ return DeterministicGenerator(num_samples, seed)
+ elif generator == 'determ-indiv':
+ return DeterministicIndividualGenerator(num_samples, seed)
+ else:
+ raise NotImplementedError
+
+
+class DummyGenerator:
+ """return Dummy generator."""
+
+ def randn(self, *args, **kwargs):
+ """return random tensor."""
+ return torch.randn(*args, **kwargs)
+
+ def randint(self, *args, **kwargs):
+ """return random int tensor."""
+ return torch.randint(*args, **kwargs)
+
+ def randn_like(self, *args, **kwargs):
+ """return random like tensor."""
+ return torch.randn_like(*args, **kwargs)
+
+
+class DeterministicGenerator:
+ """RNG to deterministically sample num_samples samples that does not depend
+ on batch_size or mpi_machines Uses a single rng and samples num_samples
+ sized randomness and subsamples the current indices."""
+
+ def __init__(self, num_samples, seed=0):
+ if dist.is_initialized():
+ self.rank = dist.get_rank()
+ self.world_size = dist.get_world_size()
+ else:
+ print('Warning: Distributed not initialised, using single rank')
+ self.rank = 0
+ self.world_size = 1
+ self.num_samples = num_samples
+ self.done_samples = 0
+ self.seed = seed
+ self.rng_cpu = torch.Generator()
+ if torch.cuda.is_available():
+ self.rng_cuda = torch.Generator(device())
+ self.set_seed(seed)
+
+ def get_global_size_and_indices(self, size):
+ """return size and indices."""
+ global_size = (self.num_samples, *size[1:])
+ indices = torch.arange(
+ self.done_samples + self.rank,
+ self.done_samples + self.world_size * int(size[0]),
+ self.world_size,
+ )
+ indices = torch.clamp(indices, 0, self.num_samples - 1)
+ assert (
+ len(indices) == size[0]
+ ), f'rank={self.rank}, ws={self.world_size}, ' \
+ f'l={len(indices)}, bs={size[0]}'
+ return global_size, indices
+
+ def get_generator(self, device):
+ """return rng generator."""
+ return self.rng_cpu if torch.device(
+ device).type == 'cpu' else self.rng_cuda
+
+ def randn(self, *size, dtype=torch.float, device='cpu'):
+ """return random tensor."""
+ global_size, indices = self.get_global_size_and_indices(size)
+ generator = self.get_generator(device)
+ return torch.randn(
+ *global_size, generator=generator, dtype=dtype,
+ device=device)[indices]
+
+ def randint(self, low, high, size, dtype=torch.long, device='cpu'):
+ """return random int tensor."""
+ global_size, indices = self.get_global_size_and_indices(size)
+ generator = self.get_generator(device)
+ return torch.randint(
+ low,
+ high,
+ generator=generator,
+ size=global_size,
+ dtype=dtype,
+ device=device)[indices]
+
+ def randn_like(self, tensor):
+ """return random like tensor."""
+ size, dtype, device = tensor.size(), tensor.dtype, tensor.device
+ return self.randn(*size, dtype=dtype, device=device)
+
+ def set_done_samples(self, done_samples):
+ """set model's done_samples."""
+ self.done_samples = done_samples
+ self.set_seed(self.seed)
+
+ def get_seed(self):
+ """return model's seed."""
+ return self.seed
+
+ def set_seed(self, seed):
+ """set model's seed."""
+ self.rng_cpu.manual_seed(seed)
+ if torch.cuda.is_available():
+ self.rng_cuda.manual_seed(seed)
+
+
+class DeterministicIndividualGenerator:
+ """RNG to deterministically sample num_samples samples that does not depend
+ on batch_size or mpi_machines Uses a separate rng for each sample to reduce
+ memory usage."""
+
+ def __init__(self, num_samples, seed=0):
+ if dist.is_initialized():
+ self.rank = dist.get_rank()
+ self.world_size = dist.get_world_size()
+ else:
+ print('Warning: Distributed not initialised, using single rank')
+ self.rank = 0
+ self.world_size = 1
+ self.num_samples = num_samples
+ self.done_samples = 0
+ self.seed = seed
+ self.rng_cpu = [torch.Generator() for _ in range(num_samples)]
+ if torch.cuda.is_available():
+ self.rng_cuda = [
+ torch.Generator(device()) for _ in range(num_samples)
+ ]
+ self.set_seed(seed)
+
+ def get_size_and_indices(self, size):
+ """return size and indices."""
+ indices = torch.arange(
+ self.done_samples + self.rank,
+ self.done_samples + self.world_size * int(size[0]),
+ self.world_size,
+ )
+ indices = torch.clamp(indices, 0, self.num_samples - 1)
+ assert (
+ len(indices) == size[0]
+ ), f'rank={self.rank}, ws={self.world_size}, ' \
+ f'l={len(indices)}, bs={size[0]}'
+ return (1, *size[1:]), indices
+
+ def get_generator(self, device):
+ """return generator."""
+ return self.rng_cpu if torch.device(
+ device).type == 'cpu' else self.rng_cuda
+
+ def randn(self, *size, dtype=torch.float, device='cpu'):
+ """return random generator."""
+ size, indices = self.get_size_and_indices(size)
+ generator = self.get_generator(device)
+ return torch.cat(
+ [
+ torch.randn(
+ *size, generator=generator[i], dtype=dtype, device=device)
+ for i in indices
+ ],
+ dim=0,
+ )
+
+ def randint(self, low, high, size, dtype=torch.long, device='cpu'):
+ """return random int generator."""
+ size, indices = self.get_size_and_indices(size)
+ generator = self.get_generator(device)
+ return torch.cat(
+ [
+ torch.randint(
+ low,
+ high,
+ generator=generator[i],
+ size=size,
+ dtype=dtype,
+ device=device,
+ ) for i in indices
+ ],
+ dim=0,
+ )
+
+ def randn_like(self, tensor):
+ """return random like tensor."""
+ size, dtype, device = tensor.size(), tensor.dtype, tensor.device
+ return self.randn(*size, dtype=dtype, device=device)
+
+ def set_done_samples(self, done_samples):
+ """set model's done_samples."""
+ self.done_samples = done_samples
+
+ def get_seed(self):
+ """return model's seed."""
+ return self.seed
+
+ def set_seed(self, seed):
+ """set model's seed."""
+ [
+ rng_cpu.manual_seed(i + self.num_samples * seed)
+ for i, rng_cpu in enumerate(self.rng_cpu)
+ ]
+ if torch.cuda.is_available():
+ [
+ rng_cuda.manual_seed(i + self.num_samples * seed)
+ for i, rng_cuda in enumerate(self.rng_cuda)
+ ]
diff --git a/tests/test_models/test_editors/test_consistency_models/test_consistency_models.py b/tests/test_models/test_editors/test_consistency_models/test_consistency_models.py
new file mode 100644
index 0000000000..e704e80936
--- /dev/null
+++ b/tests/test_models/test_editors/test_consistency_models/test_consistency_models.py
@@ -0,0 +1,136 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import gc
+import platform
+from copy import deepcopy
+from unittest import TestCase
+
+import pytest
+import torch
+
+from mmagic.models import (ConsistencyModel, ConsistencyUNetModel,
+ DataPreprocessor, KarrasDenoiser)
+from mmagic.registry import MODELS
+from mmagic.utils import register_all_modules
+
+gc.collect()
+torch.cuda.empty_cache()
+register_all_modules()
+denoiser_config = dict(
+ type='KarrasDenoiser',
+ sigma_data=0.5,
+ sigma_max=80.0,
+ sigma_min=0.002,
+ weight_schedule='uniform',
+)
+
+unet_config = dict(
+ type='ConsistencyUNetModel',
+ in_channels=3,
+ model_channels=192,
+ num_res_blocks=3,
+ dropout=0.0,
+ channel_mult='',
+ use_checkpoint=False,
+ use_fp16=False,
+ num_head_channels=64,
+ num_heads=4,
+ num_heads_upsample=-1,
+ resblock_updown=True,
+ use_new_attention_order=False,
+ use_scale_shift_norm=True)
+
+config_onestep = dict(
+ type='ConsistencyModel',
+ unet=unet_config,
+ denoiser=denoiser_config,
+ attention_resolutions='32,16,8',
+ batch_size=4,
+ class_cond=True,
+ generator='determ',
+ image_size=64,
+ learn_sigma=False,
+ model_path=None,
+ num_classes=1000,
+ sampler='onestep',
+ seed=42,
+ training_mode='consistency_distillation',
+ ts='',
+ data_preprocessor=dict(
+ type='DataPreprocessor', mean=[127.5] * 3, std=[127.5] * 3))
+
+config_multistep = dict(
+ type='ConsistencyModel',
+ unet=unet_config,
+ denoiser=denoiser_config,
+ attention_resolutions='32,16,8',
+ batch_size=4,
+ class_cond=True,
+ generator='determ',
+ image_size=64,
+ learn_sigma=False,
+ model_path=None,
+ num_classes=1000,
+ sampler='multistep',
+ seed=42,
+ steps=40,
+ training_mode='consistency_distillation',
+ ts='0,22,39',
+ data_preprocessor=dict(
+ type='DataPreprocessor', mean=[127.5] * 3, std=[127.5] * 3))
+
+
+@pytest.mark.skipif(
+ 'win' in platform.system().lower(),
+ reason='skip on windows due to limited RAM.')
+class TestDeblurGanV2(TestCase):
+
+ def test_init(self):
+ model = ConsistencyModel(
+ unet=unet_config,
+ denoiser=denoiser_config,
+ data_preprocessor=DataPreprocessor())
+ self.assertIsInstance(model, ConsistencyModel)
+ self.assertIsInstance(model.data_preprocessor, DataPreprocessor)
+ self.assertIsInstance(model.model, ConsistencyUNetModel)
+ self.assertIsInstance(model.diffusion, KarrasDenoiser)
+ unet_cfg = deepcopy(unet_config)
+ diffuse_cfg = deepcopy(denoiser_config)
+ unet = MODELS.build(unet_cfg)
+ diffuse = MODELS.build(diffuse_cfg)
+ model = ConsistencyModel(
+ unet=unet, denoiser=diffuse, data_preprocessor=DataPreprocessor())
+ self.assertIsInstance(model.model, ConsistencyUNetModel)
+ self.assertIsInstance(model.diffusion, KarrasDenoiser)
+
+ def test_onestep_infer(self):
+ model = MODELS.build(config_onestep)
+ data = {
+ 'num_batches': model.batch_size,
+ 'labels': None,
+ 'sample_model': 'orig'
+ }
+ result = model(data)
+ assert len(result) == model.batch_size
+ for datasample in result:
+ assert datasample.fake_img.shape == (3, model.image_size,
+ model.image_size)
+
+ def test_multistep_infer(self):
+ model = MODELS.build(config_multistep)
+ data = {
+ 'num_batches': model.batch_size,
+ 'labels': None,
+ 'sample_model': 'orig'
+ }
+ result = model(data)
+ assert len(result) == model.batch_size
+ for datasample in result:
+ assert datasample.fake_img.shape == (3, model.image_size,
+ model.image_size)
+
+
+def teardown_module():
+ import gc
+ gc.collect()
+ globals().clear()
+ locals().clear()
From dd68b0b6a76d9669f7d52820e128bdc4f719a874 Mon Sep 17 00:00:00 2001
From: xiaomile <15622388695@163.com>
Date: Tue, 12 Dec 2023 14:59:16 +0800
Subject: [PATCH 2/9] rerun ci check
rerun ci check
---
.../editors/consistency_models/consistencymodel_modules.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/mmagic/models/editors/consistency_models/consistencymodel_modules.py b/mmagic/models/editors/consistency_models/consistencymodel_modules.py
index 990535ebcc..486f372990 100644
--- a/mmagic/models/editors/consistency_models/consistencymodel_modules.py
+++ b/mmagic/models/editors/consistency_models/consistencymodel_modules.py
@@ -88,7 +88,6 @@ def __init__(
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
-
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
From b374173fc43e2e9da955976bdad9a61614ebb030 Mon Sep 17 00:00:00 2001
From: xiaomile <15622388695@163.com>
Date: Tue, 12 Dec 2023 15:37:07 +0800
Subject: [PATCH 3/9] Update model-index.yml
---
model-index.yml | 1 +
1 file changed, 1 insertion(+)
diff --git a/model-index.yml b/model-index.yml
index cdb2cd1fc9..7cfa8f686e 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -5,6 +5,7 @@ Import:
- configs/basicvsr_pp/metafile.yml
- configs/biggan/metafile.yml
- configs/cain/metafile.yml
+- configs/consistency_models/metafile.yml
- configs/controlnet/metafile.yml
- configs/controlnet_animation/metafile.yml
- configs/cyclegan/metafile.yml
From f72fdad1bbd9f15e738a00d01c287eb8a4ad7059 Mon Sep 17 00:00:00 2001
From: xiaomile <15622388695@163.com>
Date: Wed, 13 Dec 2023 11:09:03 +0800
Subject: [PATCH 4/9] delete unused code
---
.../consistency_models/consistencymodel.py | 19 +++++++-----
.../consistencymodel_utils.py | 29 ++-----------------
.../test_consistency_models.py | 15 +++++++++-
3 files changed, 28 insertions(+), 35 deletions(-)
diff --git a/mmagic/models/editors/consistency_models/consistencymodel.py b/mmagic/models/editors/consistency_models/consistencymodel.py
index eb31aefa19..94b407d412 100644
--- a/mmagic/models/editors/consistency_models/consistencymodel.py
+++ b/mmagic/models/editors/consistency_models/consistencymodel.py
@@ -13,7 +13,7 @@
from mmagic.registry import MODELS
from mmagic.structures import DataSample
from mmagic.utils import ForwardInputs
-from .consistencymodel_utils import (device, get_generator, get_sample_fn,
+from .consistencymodel_utils import (get_generator, get_sample_fn,
get_sigmas_karras, karras_sample)
ModelType = Union[Dict, nn.Module]
@@ -56,6 +56,9 @@ def __init__(self,
super().__init__(data_preprocessor=data_preprocessor)
self.num_classes = num_classes
+ self.device = torch.device('cpu')
+ if torch.cuda.is_available():
+ self.device = torch.device('cuda')
if 'consistency' in training_mode:
self.distillation = True
else:
@@ -117,7 +120,7 @@ def __init__(self,
self.model.load_state_dict(
torch.load(model_path, map_location='cpu'))
- self.model.to(device())
+ self.model.to(self.device)
if sampler == 'multistep':
assert len(ts) > 0
@@ -147,7 +150,7 @@ def infer(self, class_id: Optional[int] = None):
(self.batch_size, 3, self.image_size, self.image_size),
steps=self.steps,
model_kwargs=self.model_kwargs,
- device=device(),
+ device=self.device,
clip_denoised=self.clip_denoised,
sampler=self.sampler,
sigma_min=self.sigma_min,
@@ -216,18 +219,18 @@ def forward(self,
self.sigma_min,
self.sigma_max,
self.diffusion.rho,
- device=device())
+ device=self.device)
else:
sigmas = get_sigmas_karras(
self.steps,
self.sigma_min,
self.sigma_max,
self.diffusion.rho,
- device=device())
+ device=self.device)
noise = self.generator.randn(
*(self.batch_size, 3, self.image_size, self.image_size),
- device=device()) * self.sigma_max
+ device=self.device) * self.sigma_max
sample_fn = get_sample_fn(self.sampler)
@@ -291,11 +294,11 @@ def label_fn(self, class_id):
'it should be within the range (0,num_classes).'
classes = torch.tensor(
[int(class_id) for i in range(self.batch_size)],
- device=device())
+ device=self.device)
else:
classes = torch.randint(
low=0,
high=self.num_classes,
size=(self.batch_size, ),
- device=device())
+ device=self.device)
return classes
diff --git a/mmagic/models/editors/consistency_models/consistencymodel_utils.py b/mmagic/models/editors/consistency_models/consistencymodel_utils.py
index 73354f8e60..ba9e2e5670 100644
--- a/mmagic/models/editors/consistency_models/consistencymodel_utils.py
+++ b/mmagic/models/editors/consistency_models/consistencymodel_utils.py
@@ -7,30 +7,6 @@
import torch.nn as nn
-def device():
- """return torch.device."""
- if torch.cuda.is_available():
- return torch.device('cuda')
- return torch.device('cpu')
-
-
-def get_weightings(weight_schedule, snrs, sigma_data):
- """return weightings."""
- if weight_schedule == 'snr':
- weightings = snrs
- elif weight_schedule == 'snr+1':
- weightings = snrs + 1
- elif weight_schedule == 'karras':
- weightings = snrs + 1.0 / sigma_data**2
- elif weight_schedule == 'truncated-snr':
- weightings = torch.clamp(snrs, min=1.0)
- elif weight_schedule == 'uniform':
- weightings = torch.ones_like(snrs)
- else:
- raise NotImplementedError()
- return weightings
-
-
class SiLU(nn.Module):
"""PyTorch 1.7 has SiLU, but we support PyTorch 1.5."""
@@ -654,7 +630,7 @@ def __init__(self, num_samples, seed=0):
self.seed = seed
self.rng_cpu = torch.Generator()
if torch.cuda.is_available():
- self.rng_cuda = torch.Generator(device())
+ self.rng_cuda = torch.Generator(torch.device('cuda'))
self.set_seed(seed)
def get_global_size_and_indices(self, size):
@@ -737,7 +713,8 @@ def __init__(self, num_samples, seed=0):
self.rng_cpu = [torch.Generator() for _ in range(num_samples)]
if torch.cuda.is_available():
self.rng_cuda = [
- torch.Generator(device()) for _ in range(num_samples)
+ torch.Generator(torch.device('cuda'))
+ for _ in range(num_samples)
]
self.set_seed(seed)
diff --git a/tests/test_models/test_editors/test_consistency_models/test_consistency_models.py b/tests/test_models/test_editors/test_consistency_models/test_consistency_models.py
index e704e80936..408c6d9800 100644
--- a/tests/test_models/test_editors/test_consistency_models/test_consistency_models.py
+++ b/tests/test_models/test_editors/test_consistency_models/test_consistency_models.py
@@ -82,17 +82,20 @@
@pytest.mark.skipif(
'win' in platform.system().lower(),
reason='skip on windows due to limited RAM.')
-class TestDeblurGanV2(TestCase):
+class TestConsistencyModels(TestCase):
def test_init(self):
model = ConsistencyModel(
unet=unet_config,
denoiser=denoiser_config,
data_preprocessor=DataPreprocessor())
+ if torch.cuda.is_available():
+            self.assertEqual(model.device, torch.device('cuda'))
self.assertIsInstance(model, ConsistencyModel)
self.assertIsInstance(model.data_preprocessor, DataPreprocessor)
self.assertIsInstance(model.model, ConsistencyUNetModel)
self.assertIsInstance(model.diffusion, KarrasDenoiser)
+
unet_cfg = deepcopy(unet_config)
diffuse_cfg = deepcopy(denoiser_config)
unet = MODELS.build(unet_cfg)
@@ -114,6 +117,11 @@ def test_onestep_infer(self):
for datasample in result:
assert datasample.fake_img.shape == (3, model.image_size,
model.image_size)
+ result, labels = model.infer()
+ assert len(result) == model.batch_size
+ assert len(labels) == model.batch_size
+ for datasample in result:
+ assert datasample.shape == (model.image_size, model.image_size, 3)
def test_multistep_infer(self):
model = MODELS.build(config_multistep)
@@ -127,6 +135,11 @@ def test_multistep_infer(self):
for datasample in result:
assert datasample.fake_img.shape == (3, model.image_size,
model.image_size)
+ result, labels = model.infer()
+ assert len(result) == model.batch_size
+ assert len(labels) == model.batch_size
+ for datasample in result:
+ assert datasample.shape == (model.image_size, model.image_size, 3)
def teardown_module():
From 45d9a1e0b9337f524e55726a0dc9a1a9fea51536 Mon Sep 17 00:00:00 2001
From: xiaomile <15622388695@163.com>
Date: Wed, 13 Dec 2023 15:29:53 +0800
Subject: [PATCH 5/9] Create test_consistency_model_utils.py
---
.../test_consistency_model_utils.py | 97 +++++++++++++++++++
1 file changed, 97 insertions(+)
create mode 100644 tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py
diff --git a/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py b/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py
new file mode 100644
index 0000000000..95df5c2812
--- /dev/null
+++ b/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py
@@ -0,0 +1,97 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import gc
+import platform
+from copy import deepcopy
+from unittest import TestCase
+
+import pytest
+import torch
+
+from mmagic.models.editors.consistency_models.consistencymodel_utils import (
+ DeterministicGenerator, DeterministicIndividualGenerator, DummyGenerator,
+ get_generator, get_sample_fn, karras_sample, sample_dpm, sample_euler,
+ sample_euler_ancestral, sample_heun, sample_onestep, sample_progdist,
+ stochastic_iterative_sampler)
+from mmagic.registry import MODELS
+from mmagic.utils import register_all_modules
+
+gc.collect()
+torch.cuda.empty_cache()
+register_all_modules()
+denoiser_config = dict(
+ type='KarrasDenoiser',
+ sigma_data=0.5,
+ sigma_max=80.0,
+ sigma_min=0.002,
+ weight_schedule='uniform',
+)
+
+unet_config = dict(
+ type='ConsistencyUNetModel',
+ image_size=64,
+ out_channels=3,
+ attention_resolutions=(2, 4, 8),
+ in_channels=3,
+ model_channels=192,
+ num_res_blocks=3,
+ dropout=0.0,
+ channel_mult=(1, 2, 3, 4),
+ use_checkpoint=False,
+ use_fp16=False,
+ num_head_channels=64,
+ num_heads=4,
+ num_heads_upsample=-1,
+ resblock_updown=True,
+ use_new_attention_order=False,
+ use_scale_shift_norm=True)
+
+
+@pytest.mark.skipif(
+ 'win' in platform.system().lower(),
+ reason='skip on windows due to limited RAM.')
+class TestConsistencyModelUtils(TestCase):
+
+ def test_karras_sample(self):
+ unet_cfg = deepcopy(unet_config)
+ diffuse_cfg = deepcopy(denoiser_config)
+ unet = MODELS.build(unet_cfg)
+ diffuse = MODELS.build(diffuse_cfg)
+ image_size = 64
+ channel_num = 3
+ steps = 2
+ batch_size = 4
+ model_kwargs = {}
+ sample = karras_sample(
+ diffuse,
+ unet, (batch_size, channel_num, image_size, image_size),
+ steps=steps,
+ model_kwargs=model_kwargs)
+ assert sample.shape == (batch_size, channel_num, image_size,
+ image_size)
+
+ def test_get_generator(self):
+ self.assertIsInstance(get_generator('dummy'), DummyGenerator)
+ self.assertIsInstance(get_generator('determ'), DeterministicGenerator)
+ self.assertIsInstance(
+ get_generator('determ-indiv'), DeterministicIndividualGenerator)
+ with pytest.raises(NotImplementedError):
+ get_generator('')
+
+ def test_sample_fn(self):
+ self.assertEqual(get_sample_fn('heun'), sample_heun)
+ self.assertEqual(get_sample_fn('dpm'), sample_dpm)
+ self.assertEqual(get_sample_fn('ancestral'), sample_euler_ancestral)
+ self.assertEqual(get_sample_fn('onestep'), sample_onestep)
+ self.assertEqual(get_sample_fn('progdist'), sample_progdist)
+ self.assertEqual(get_sample_fn('euler'), sample_euler)
+ self.assertEqual(
+ get_sample_fn('multistep'), stochastic_iterative_sampler)
+ with pytest.raises(KeyError):
+ get_sample_fn('')
+
+
+def teardown_module():
+ import gc
+ gc.collect()
+ globals().clear()
+ locals().clear()
From 1bb079628e9c39afad50b66daecbf5bc000459c5 Mon Sep 17 00:00:00 2001
From: xiaomile <15622388695@163.com>
Date: Wed, 13 Dec 2023 16:47:40 +0800
Subject: [PATCH 6/9] Update test_consistency_model_utils.py
---
.../test_consistency_model_utils.py | 65 +++++++++++++++++--
1 file changed, 61 insertions(+), 4 deletions(-)
diff --git a/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py b/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py
index 95df5c2812..79d671e9fb 100644
--- a/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py
+++ b/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py
@@ -9,9 +9,9 @@
from mmagic.models.editors.consistency_models.consistencymodel_utils import (
DeterministicGenerator, DeterministicIndividualGenerator, DummyGenerator,
- get_generator, get_sample_fn, karras_sample, sample_dpm, sample_euler,
- sample_euler_ancestral, sample_heun, sample_onestep, sample_progdist,
- stochastic_iterative_sampler)
+ get_generator, get_sample_fn, get_sigmas_karras, karras_sample, sample_dpm,
+ sample_euler, sample_euler_ancestral, sample_heun, sample_onestep,
+ sample_progdist, stochastic_iterative_sampler)
from mmagic.registry import MODELS
from mmagic.utils import register_all_modules
@@ -82,13 +82,70 @@ def test_sample_fn(self):
self.assertEqual(get_sample_fn('dpm'), sample_dpm)
self.assertEqual(get_sample_fn('ancestral'), sample_euler_ancestral)
self.assertEqual(get_sample_fn('onestep'), sample_onestep)
- self.assertEqual(get_sample_fn('progdist'), sample_progdist)
self.assertEqual(get_sample_fn('euler'), sample_euler)
self.assertEqual(
get_sample_fn('multistep'), stochastic_iterative_sampler)
+ self.assertEqual(get_sample_fn('progdist'), sample_progdist)
with pytest.raises(KeyError):
get_sample_fn('')
+ clip_denoised = True
+ unet_cfg = deepcopy(unet_config)
+ diffuse_cfg = deepcopy(denoiser_config)
+ unet = MODELS.build(unet_cfg)
+ diffuse = MODELS.build(diffuse_cfg)
+ model_kwargs = {}
+ device = torch.device('cpu')
+ sigma_max = 80
+ sigma_min = 0.002
+ s_churn = 0.0
+ s_tmin = 0.0
+ s_tmax = float('inf')
+ s_noise = 1.0
+ ts = (1, 2)
+ shape = (4, 3, 64, 64)
+ generator = get_generator('dummy')
+ x_T = generator.randn(*shape, device=device) * sigma_max
+
+ def denoiser(x_t, sigma):
+ """denoiser function."""
+ _, denoised = diffuse.denoise(unet, x_t, sigma, **model_kwargs)
+ if clip_denoised:
+ denoised = denoised.clamp(-1, 1)
+ return denoised
+
+ sample_list = [
+ 'heun', 'dpm', 'ancestral', 'onestep', 'progdist', 'euler',
+ 'multistep'
+ ]
+ for sample in sample_list:
+ if sample == 'progdist':
+ sigmas = get_sigmas_karras(
+ 2 + 1, sigma_min, sigma_max, 7.0, device=device)
+ else:
+ sigmas = get_sigmas_karras(
+ 2, sigma_min, sigma_max, 7.0, device=device)
+ if sample in ['heun', 'dpm']:
+ sampler_args = dict(
+ s_churn=s_churn,
+ s_tmin=s_tmin,
+ s_tmax=s_tmax,
+ s_noise=s_noise)
+ elif sample == 'multistep':
+ sampler_args = dict(
+ ts=ts, t_min=sigma_min, t_max=sigma_max, rho=7.0, steps=2)
+ else:
+ sampler_args = {}
+ assert get_sample_fn(sample)(
+ denoiser,
+ x_T,
+ sigmas,
+ generator,
+ progress=False,
+ callback=None,
+ **sampler_args,
+ ).clamp(-1, 1).shape == shape
+
def teardown_module():
import gc
From 05fe595be2306c371c2fe6b758bc8b352f111b10 Mon Sep 17 00:00:00 2001
From: xiaomile <15622388695@163.com>
Date: Wed, 13 Dec 2023 17:47:22 +0800
Subject: [PATCH 7/9] add more test
---
.../consistencymodel_utils.py | 1 +
.../test_consistency_model_utils.py | 79 +++++++++++--------
2 files changed, 45 insertions(+), 35 deletions(-)
diff --git a/mmagic/models/editors/consistency_models/consistencymodel_utils.py b/mmagic/models/editors/consistency_models/consistencymodel_utils.py
index ba9e2e5670..78dc1d0318 100644
--- a/mmagic/models/editors/consistency_models/consistencymodel_utils.py
+++ b/mmagic/models/editors/consistency_models/consistencymodel_utils.py
@@ -98,6 +98,7 @@ def get_sample_fn(sampler):
'progdist': sample_progdist,
'euler': sample_euler,
'multistep': stochastic_iterative_sampler,
+ 'midpoint': sample_midpoint_ancestral,
}[sampler]
diff --git a/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py b/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py
index 79d671e9fb..2c822f3e6e 100644
--- a/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py
+++ b/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py
@@ -10,8 +10,9 @@
from mmagic.models.editors.consistency_models.consistencymodel_utils import (
DeterministicGenerator, DeterministicIndividualGenerator, DummyGenerator,
get_generator, get_sample_fn, get_sigmas_karras, karras_sample, sample_dpm,
- sample_euler, sample_euler_ancestral, sample_heun, sample_onestep,
- sample_progdist, stochastic_iterative_sampler)
+ sample_euler, sample_euler_ancestral, sample_heun,
+ sample_midpoint_ancestral, sample_onestep, sample_progdist,
+ stochastic_iterative_sampler)
from mmagic.registry import MODELS
from mmagic.utils import register_all_modules
@@ -83,12 +84,14 @@ def test_sample_fn(self):
self.assertEqual(get_sample_fn('ancestral'), sample_euler_ancestral)
self.assertEqual(get_sample_fn('onestep'), sample_onestep)
self.assertEqual(get_sample_fn('euler'), sample_euler)
+ self.assertEqual(get_sample_fn('midpoint'), sample_midpoint_ancestral)
self.assertEqual(
get_sample_fn('multistep'), stochastic_iterative_sampler)
self.assertEqual(get_sample_fn('progdist'), sample_progdist)
with pytest.raises(KeyError):
get_sample_fn('')
+ def test_sample_fn_and_get_generator(self):
clip_denoised = True
unet_cfg = deepcopy(unet_config)
diffuse_cfg = deepcopy(denoiser_config)
@@ -104,8 +107,10 @@ def test_sample_fn(self):
s_noise = 1.0
ts = (1, 2)
shape = (4, 3, 64, 64)
- generator = get_generator('dummy')
- x_T = generator.randn(*shape, device=device) * sigma_max
+ sample_list = [
+ 'heun', 'dpm', 'ancestral', 'onestep', 'progdist', 'euler',
+ 'multistep', 'midpoint'
+ ]
def denoiser(x_t, sigma):
"""denoiser function."""
@@ -114,37 +119,41 @@ def denoiser(x_t, sigma):
denoised = denoised.clamp(-1, 1)
return denoised
- sample_list = [
- 'heun', 'dpm', 'ancestral', 'onestep', 'progdist', 'euler',
- 'multistep'
- ]
- for sample in sample_list:
- if sample == 'progdist':
- sigmas = get_sigmas_karras(
- 2 + 1, sigma_min, sigma_max, 7.0, device=device)
- else:
- sigmas = get_sigmas_karras(
- 2, sigma_min, sigma_max, 7.0, device=device)
- if sample in ['heun', 'dpm']:
- sampler_args = dict(
- s_churn=s_churn,
- s_tmin=s_tmin,
- s_tmax=s_tmax,
- s_noise=s_noise)
- elif sample == 'multistep':
- sampler_args = dict(
- ts=ts, t_min=sigma_min, t_max=sigma_max, rho=7.0, steps=2)
- else:
- sampler_args = {}
- assert get_sample_fn(sample)(
- denoiser,
- x_T,
- sigmas,
- generator,
- progress=False,
- callback=None,
- **sampler_args,
- ).clamp(-1, 1).shape == shape
+ generator_list = ['dummy', 'determ', 'determ-indiv']
+ for generator_str in generator_list:
+ generator = get_generator(generator_str, 4, 0)
+ x_T = generator.randn(*shape, device=device) * sigma_max
+ for sample in sample_list:
+ if sample == 'progdist':
+ sigmas = get_sigmas_karras(
+ 2 + 1, sigma_min, sigma_max, 7.0, device=device)
+ else:
+ sigmas = get_sigmas_karras(
+ 2, sigma_min, sigma_max, 7.0, device=device)
+ if sample in ['heun', 'dpm']:
+ sampler_args = dict(
+ s_churn=s_churn,
+ s_tmin=s_tmin,
+ s_tmax=s_tmax,
+ s_noise=s_noise)
+ elif sample == 'multistep':
+ sampler_args = dict(
+ ts=ts,
+ t_min=sigma_min,
+ t_max=sigma_max,
+ rho=7.0,
+ steps=2)
+ else:
+ sampler_args = {}
+ assert get_sample_fn(sample)(
+ denoiser,
+ x_T,
+ sigmas,
+ generator,
+ progress=False,
+ callback=None,
+ **sampler_args,
+ ).clamp(-1, 1).shape == shape
def teardown_module():
From f68dc9987a7bc9842573294d56cb47f406beef1e Mon Sep 17 00:00:00 2001
From: xiaomile <15622388695@163.com>
Date: Wed, 13 Dec 2023 19:00:16 +0800
Subject: [PATCH 8/9] add more test
---
.../editors/consistency_models/consistencymodel_modules.py | 4 ++--
.../test_consistency_models/test_consistency_model_utils.py | 3 +++
2 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/mmagic/models/editors/consistency_models/consistencymodel_modules.py b/mmagic/models/editors/consistency_models/consistencymodel_modules.py
index 486f372990..70376cdd9b 100644
--- a/mmagic/models/editors/consistency_models/consistencymodel_modules.py
+++ b/mmagic/models/editors/consistency_models/consistencymodel_modules.py
@@ -235,13 +235,13 @@ def __init__(
self.convert_to_fp16()
def convert_to_fp16(self):
- """Convert the torso of the model to float16."""
+ """Convert the tensor of the model to float16."""
self.input_blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
self.output_blocks.apply(convert_module_to_f16)
def convert_to_fp32(self):
- """Convert the torso of the model to float32."""
+ """Convert the tensor of the model to float32."""
self.input_blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
diff --git a/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py b/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py
index 2c822f3e6e..400ec5edce 100644
--- a/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py
+++ b/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py
@@ -69,6 +69,8 @@ def test_karras_sample(self):
model_kwargs=model_kwargs)
assert sample.shape == (batch_size, channel_num, image_size,
image_size)
+ unet.convert_to_fp32()
+ unet.convert_to_fp16()
def test_get_generator(self):
self.assertIsInstance(get_generator('dummy'), DummyGenerator)
@@ -122,6 +124,7 @@ def denoiser(x_t, sigma):
generator_list = ['dummy', 'determ', 'determ-indiv']
for generator_str in generator_list:
generator = get_generator(generator_str, 4, 0)
+ generator.randint(1, 2, (1, 2), dtype=torch.long, device='cpu')
x_T = generator.randn(*shape, device=device) * sigma_max
for sample in sample_list:
if sample == 'progdist':
From 015e5d47424d6a6fa77d17c5424126716b7041e9 Mon Sep 17 00:00:00 2001
From: xiaomile <15622388695@163.com>
Date: Thu, 14 Dec 2023 10:10:50 +0800
Subject: [PATCH 9/9] Update test_consistency_models.py
rerun ci check
---
.../test_consistency_models/test_consistency_models.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/tests/test_models/test_editors/test_consistency_models/test_consistency_models.py b/tests/test_models/test_editors/test_consistency_models/test_consistency_models.py
index 408c6d9800..d455d939cc 100644
--- a/tests/test_models/test_editors/test_consistency_models/test_consistency_models.py
+++ b/tests/test_models/test_editors/test_consistency_models/test_consistency_models.py
@@ -95,7 +95,6 @@ def test_init(self):
self.assertIsInstance(model.data_preprocessor, DataPreprocessor)
self.assertIsInstance(model.model, ConsistencyUNetModel)
self.assertIsInstance(model.diffusion, KarrasDenoiser)
-
unet_cfg = deepcopy(unet_config)
diffuse_cfg = deepcopy(denoiser_config)
unet = MODELS.build(unet_cfg)