[Feature] Support denoising demo #202

Merged: 10 commits, Jan 11, 2022
155 changes: 155 additions & 0 deletions demo/ddpm_demo.py
@@ -0,0 +1,155 @@
import argparse
import os
import os.path as osp
import sys

import mmcv
import numpy as np
import torch
from mmcv import DictAction
from torchvision import utils

# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../..'))) # isort:skip # noqa

from mmgen.apis import init_model, sample_ddpm_model # isort:skip # noqa
# yapf: enable


def parse_args():
parser = argparse.ArgumentParser(description='DDPM demo')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--save-path',
type=str,
default='./work_dirs/demos/ddpm_samples.png',
        help='path to save unconditional samples')
parser.add_argument(
'--device', type=str, default='cuda:0', help='CUDA device id')

# args for inference/sampling
parser.add_argument(
'--num-batches', type=int, default=4, help='Batch size in inference')
parser.add_argument(
'--num-samples',
type=int,
default=12,
help='The total number of samples')
parser.add_argument(
'--sample-model',
type=str,
default='ema',
help='Which model to use for sampling')
parser.add_argument(
'--sample-cfg',
nargs='+',
action=DictAction,
help='Other customized kwargs for sampling function')
parser.add_argument(
'--same-noise',
action='store_true',
        help='whether to use the same noise as input (x_T)')
parser.add_argument(
'--n-skip',
type=int,
default=25,
        help=('Skip this many steps before selecting one frame to '
              'visualize. This is helpful when the number of denoising '
              'timesteps is large. Only takes effect when `save-path` '
              'ends with \'.gif\'.'))

# args for image grid
parser.add_argument(
'--padding', type=int, default=0, help='Padding in the image grid.')
parser.add_argument(
'--nrow',
type=int,
default=2,
help=('Number of images displayed in each row of the grid. '
'This argument would work only when label is not given.'))

# args for image channel order
parser.add_argument(
'--is-rgb',
action='store_true',
        help=('If set, color channels will not be permuted. This option is '
              'useful when the model was trained with RGB images.'))

args = parser.parse_args()
return args


def create_gif(results, gif_name, fps=60, n_skip=1):
"""Create gif through imageio.

Args:
        results (torch.Tensor): Image frames, with shape like [bz, 3, H, W].
gif_name (str): Saved gif name.
fps (int, optional): Frames per second of the generated gif.
Defaults to 60.
n_skip (int, optional): Skip how many steps before selecting one to
visualize. Defaults to 1.
"""
try:
import imageio
except ImportError:
        raise RuntimeError('imageio is not installed, please use '
                           '"pip install imageio" to install it')
frames_list = []
for frame in results[::n_skip]:
frames_list.append(
(frame.permute(1, 2, 0).cpu().numpy() * 255.).astype(np.uint8))

    # ensure the final denoising result is included in frames_list
    if (len(results) - 1) % n_skip != 0:
        frames_list.append((results[-1].permute(1, 2, 0).cpu().numpy() *
                            255.).astype(np.uint8))

imageio.mimsave(gif_name, frames_list, 'GIF', fps=fps)


def main():
args = parse_args()
model = init_model(
args.config, checkpoint=args.checkpoint, device=args.device)

if args.sample_cfg is None:
args.sample_cfg = dict()

suffix = osp.splitext(args.save_path)[-1]
if suffix == '.gif':
args.sample_cfg['save_intermedia'] = True

results = sample_ddpm_model(model, args.num_samples, args.num_batches,
args.sample_model, args.same_noise,
**args.sample_cfg)

# save images
mmcv.mkdir_or_exist(os.path.dirname(args.save_path))
if suffix == '.gif':
        # concatenate the grid image of each timestep
results_timestep_list = []
for t in results.keys():
# make grid
results_timestep = utils.make_grid(
results[t], nrow=args.nrow, padding=args.padding)
            # add a leading dim: make_grid outputs a tensor of shape [3, H', W']
results_timestep_list.append(results_timestep[None, ...])

        # concatenate to [n_timesteps, 3, H', W']
results_timestep = torch.cat(results_timestep_list, dim=0)
if not args.is_rgb:
results_timestep = results_timestep[:, [2, 1, 0]]
results_timestep = (results_timestep + 1.) / 2.
create_gif(results_timestep, args.save_path, n_skip=args.n_skip)
else:
if not args.is_rgb:
results = results[:, [2, 1, 0]]

results = (results + 1.) / 2.
utils.save_image(
results, args.save_path, nrow=args.nrow, padding=args.padding)


if __name__ == '__main__':
main()
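
For reviewers who want to try the demo, a typical invocation from the repository root might look like the sketch below. The config and checkpoint paths are placeholders, not files verified in this PR.

# Hypothetical invocation of demo/ddpm_demo.py; the config and checkpoint
# paths are placeholders and must be replaced with real files.
import subprocess

subprocess.run([
    'python', 'demo/ddpm_demo.py',
    'configs/improved_ddpm/your_ddpm_config.py',  # placeholder config
    'work_dirs/ckpt/your_ddpm_checkpoint.pth',  # placeholder checkpoint
    '--save-path', './work_dirs/demos/ddpm_samples.gif',  # '.gif' enables the animated output
    '--num-samples', '8',
    '--num-batches', '4',
    '--n-skip', '25',
], check=True)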
6 changes: 4 additions & 2 deletions mmgen/apis/__init__.py
@@ -1,9 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import (init_model, sample_conditional_model,
sample_img2img_model, sample_uncoditional_model)
sample_ddpm_model, sample_img2img_model,
sample_uncoditional_model)
from .train import set_random_seed, train_model

__all__ = [
'set_random_seed', 'train_model', 'init_model', 'sample_img2img_model',
'sample_uncoditional_model', 'sample_conditional_model'
'sample_uncoditional_model', 'sample_conditional_model',
'sample_ddpm_model'
]
68 changes: 68 additions & 0 deletions mmgen/apis/inference.py
@@ -227,3 +227,71 @@ def sample_img2img_model(model, image_path, target_domain=None, **kwargs):
**kwargs)
output = results['target']
return output


@torch.no_grad()
def sample_ddpm_model(model,
num_samples=16,
num_batches=4,
sample_model='ema',
same_noise=False,
**kwargs):
"""Sampling from ddpm models.

Args:
model (nn.Module): DDPM models in MMGeneration.
num_samples (int, optional): The total number of samples.
Defaults to 16.
        num_batches (int, optional): The batch size used for each inference
            step. Defaults to 4.
        sample_model (str, optional): Which model to use for sampling, one
            of ['ema', 'orig']. Defaults to 'ema'.
        same_noise (bool, optional): Whether to use the same noise batch as
            the denoising starting point (x_T) for every sample.
            Defaults to False.

    Returns:
        torch.Tensor | dict: Generated image tensor, or a dict mapping each
            timestep to the corresponding image tensor when intermediate
            results are saved.
"""
model.eval()

n_repeat = num_samples // num_batches
batches_list = [num_batches] * n_repeat

if num_samples % num_batches > 0:
batches_list.append(num_samples % num_batches)

noise_batch = torch.randn(model.image_shape) if same_noise else None

res_list = []
# inference
for idx, batches in enumerate(batches_list):
mmcv.print_log(
f'Start to sample batch [{idx+1} / '
f'{len(batches_list)}]', 'mmgen')
noise_batch_ = noise_batch[None, ...].expand(batches, -1, -1, -1) \
if same_noise else None

res = model.sample_from_noise(
noise_batch_,
num_batches=batches,
sample_model=sample_model,
show_pbar=True,
**kwargs)
if isinstance(res, dict):
res = {k: v.cpu() for k, v in res.items()}
elif isinstance(res, torch.Tensor):
res = res.cpu()
else:
            raise ValueError('Sample results should be \'dict\' or '
                             f'\'torch.Tensor\', but received \'{type(res)}\'')
res_list.append(res)

# gather the res_list
if isinstance(res_list[0], dict):
res_dict = dict()
for t in res_list[0].keys():
# num_samples x 3 x H x W
res_dict[t] = torch.cat([res[t] for res in res_list], dim=0)
return res_dict
else:
return torch.cat(res_list, dim=0)
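
For quick orientation, a minimal usage sketch of the new API follows; the config path is a placeholder, and the save_intermedia keyword mirrors the tests added below.

# Minimal usage sketch for sample_ddpm_model; the config path is a
# placeholder, and save_intermedia is forwarded to sample_from_noise.
from mmgen.apis import init_model, sample_ddpm_model

model = init_model('configs/improved_ddpm/your_ddpm_config.py',
                   checkpoint=None, device='cpu')

# Plain sampling returns a single tensor of shape [num_samples, 3, H, W].
imgs = sample_ddpm_model(model, num_samples=4, num_batches=2)

# With save_intermedia=True the result is a dict mapping each timestep to a
# [num_samples, 3, H, W] tensor, concatenated across inference batches.
frames = sample_ddpm_model(
    model, num_samples=4, num_batches=2, save_intermedia=True)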
51 changes: 50 additions & 1 deletion tests/test_apis/test_inference.py
@@ -4,7 +4,7 @@
import pytest
import torch

from mmgen.apis import (init_model, sample_img2img_model,
from mmgen.apis import (init_model, sample_ddpm_model, sample_img2img_model,
sample_uncoditional_model)


@@ -78,3 +78,52 @@ def test_translation_model_cuda(self):
res = sample_img2img_model(
self.cyclegan.cuda(), self.img_path, target_domain='photo')
assert res.shape == (1, 3, 256, 256)


class TestDiffusionModel:

@classmethod
def setup_class(cls):
project_dir = os.path.abspath(os.path.join(__file__, '../../..'))
ddpm_config = mmcv.Config.fromfile(
os.path.join(
project_dir, 'configs/improved_ddpm/'
'ddpm_cosine_hybird_timestep-4k_drop0.3_'
'cifar10_32x32_b8x16_500k.py'))
# change timesteps to speed up test process
ddpm_config.model['num_timesteps'] = 10
cls.model = init_model(ddpm_config, checkpoint=None, device='cpu')

def test_diffusion_model_cpu(self):
# save_intermedia is False
res = sample_ddpm_model(
self.model, num_samples=3, num_batches=2, same_noise=True)
assert res.shape == (3, 3, 32, 32)

# save_intermedia is True
res = sample_ddpm_model(
self.model,
num_samples=2,
num_batches=2,
same_noise=True,
save_intermedia=True)
assert isinstance(res, dict)
assert all([i in res for i in range(10)])

@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_diffusion_model_cuda(self):
model = self.model.cuda()
# save_intermedia is False
res = sample_ddpm_model(
model, num_samples=3, num_batches=2, same_noise=True)
assert res.shape == (3, 3, 32, 32)

# save_intermedia is True
res = sample_ddpm_model(
model,
num_samples=2,
num_batches=2,
same_noise=True,
save_intermedia=True)
assert isinstance(res, dict)
assert all([i in res for i in range(10)])
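
To run only the new diffusion tests locally, something like the following should work from the repository root, assuming pytest is installed; the selector simply names the class added above.

# Run only the new DDPM inference tests added in this PR.
import pytest

pytest.main(['tests/test_apis/test_inference.py::TestDiffusionModel', '-v'])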