sample_guide.py (forked from FutureXiang/Diffusion)
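
"""Distributed sampling script for a guided diffusion model.

Loads a trained DDPM/EDM checkpoint, draws 400 images per batch split evenly
across all processes (125 batches by default, i.e. 50,000 images -- the usual
sample count for FID evaluation), and saves both preview grids and individual
PNGs under the experiment's save_dir.

Typical launch (an assumption inferred from the --local_rank argument parsed
below; the process count and config path are placeholders):

    python -m torch.distributed.launch --nproc_per_node=4 \
        sample_guide.py --config <config.yaml> --ema --w 0.3
"""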
import argparse
import os
import random

import numpy as np
import torch
import torch.distributed as dist
import yaml
from ema_pytorch import EMA
from torchvision.utils import make_grid, save_image

from model.models import get_models_class
from utils import Config, init_seeds, gather_tensor


def get_default_steps(model_type, steps):
    # Fall back to a per-sampler default when --steps is not specified.
    if steps is not None:
        return steps
    return {'DDPM': 100, 'EDM': 18}[model_type]
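
# The EDM default of 18 steps appears to match the deterministic-sampler
# setting from Karras et al. (2022), "Elucidating the Design Space of
# Diffusion-Based Generative Models" -- an inference, not stated in this repo.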


# ===== sampling =====

def sample(opt):
    print(opt)
    yaml_path = opt.config
    local_rank = opt.local_rank
    use_amp = opt.use_amp
    mode = opt.mode
    steps = opt.steps
    eta = opt.eta
    batches = opt.batches
    use_ema = opt.ema
    ep = opt.epoch
    w = opt.w

    # Replace the argparse namespace with the experiment config from YAML;
    # the CLI values were copied out above.
    with open(yaml_path, 'r') as f:
        opt = yaml.full_load(f)
    print(opt)
    opt = Config(opt)
    if ep == -1:
        ep = opt.n_epoch - 1  # default to the last training epoch
    device = "cuda:%d" % local_rank
    steps = get_default_steps(opt.model_type, steps)

    # Build the diffusion wrapper around the denoising network; guide=True
    # selects the classifier-free-guidance variants of the model classes.
    DIFFUSION, NETWORK = get_models_class(opt.model_type, opt.net_type, guide=True)
    diff = DIFFUSION(nn_model=NETWORK(**opt.network),
                     **opt.diffusion,
                     device=device,
                     drop_prob=0.1)
    diff.to(device)

    # Load either the EMA weights or the raw DDP-trained weights.
    target = os.path.join(opt.save_dir, "ckpts", f"model_{ep}.pth")
    print("loading model at", target)
    checkpoint = torch.load(target, map_location=device)
    if use_ema:
        ema = EMA(diff, beta=0.9999, update_after_step=1000, update_every=10)
        ema.to(device)
        ema.load_state_dict(checkpoint['EMA'])
        model = ema.ema_model
        prefix = "EMA"
    else:
        diff = torch.nn.SyncBatchNorm.convert_sync_batchnorm(diff)
        diff = torch.nn.parallel.DistributedDataParallel(
            diff, device_ids=[local_rank], output_device=local_rank)
        diff.load_state_dict(checkpoint['MODEL'])
        model = diff.module
        prefix = ""
    model.eval()
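
    # Note: the two branches above load two snapshots of the same run.
    # checkpoint['EMA'] holds the exponential-moving-average copy of the
    # weights (usually the better sampler), while checkpoint['MODEL'] holds
    # the raw weights, which must be re-wrapped in SyncBatchNorm + DDP so the
    # saved parameter names resolve before unwrapping back to .module.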

    # Rank 0 creates the output directories; the directory name records every
    # sampling hyperparameter so different runs do not collide.
    if local_rank == 0:
        if opt.model_type == 'EDM':
            gen_dir = os.path.join(opt.save_dir,
                                   f"{prefix}generated_ep{ep}_w{w}_edm_steps{steps}_eta{eta}")
        elif mode == 'DDPM':
            gen_dir = os.path.join(opt.save_dir,
                                   f"{prefix}generated_ep{ep}_w{w}_ddpm")
        else:
            gen_dir = os.path.join(opt.save_dir,
                                   f"{prefix}generated_ep{ep}_w{w}_ddim_steps{steps}_eta{eta}")
        os.makedirs(gen_dir)
        gen_dir_png = os.path.join(gen_dir, "pngs")
        os.makedirs(gen_dir_png)

    res = []
    for batch in range(batches):
        with torch.no_grad():
            # Each batch yields 400 images in total, split evenly across the
            # distributed processes.
            assert 400 % dist.get_world_size() == 0
            samples_per_process = 400 // dist.get_world_size()
            args = dict(n_sample=samples_per_process,
                        size=opt.network['image_shape'],
                        guide_w=w,
                        notqdm=(local_rank != 0),  # progress bar on rank 0 only
                        use_amp=use_amp)
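            # guide_w is the classifier-free guidance scale. The usual
            # convention for this w-parameterisation (an assumption -- the
            # actual combination lives in model.models, not in this file) is
            #   eps = (1 + w) * eps(x, c) - w * eps(x)
            # so w = 0 recovers purely conditional sampling.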
            if opt.model_type == 'EDM':
                x_gen = model.edm_sample(**args, num_steps=steps, eta=eta)
            elif mode == 'DDPM':
                x_gen = model.sample(**args)
            else:
                x_gen = model.ddim_sample(**args, steps=steps, eta=eta)
        dist.barrier()
        x_gen = gather_tensor(x_gen)  # gather per-process samples from all ranks
        if local_rank == 0:
            res.append(x_gen)
            grid = make_grid(x_gen.cpu(), nrow=20)
            png_path = os.path.join(gen_dir, f"grid_{batch}.png")
            save_image(grid, png_path)

    # Rank 0 additionally writes every image as an individual PNG
    # (e.g. for FID evaluation).
    if local_rank == 0:
        res = torch.cat(res)
        for no, img in enumerate(res):
            png_path = os.path.join(gen_dir_png, f"{no}.png")
            save_image(img, png_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str)
    parser.add_argument('--local_rank', default=-1, type=int,
                        help='local (per-node) GPU rank for distributed sampling')
    parser.add_argument("--use_amp", action='store_true', default=False)
    parser.add_argument("--mode", type=str, choices=['DDPM', 'DDIM'], default='DDIM')
    parser.add_argument("--steps", type=int, default=None)
    parser.add_argument("--eta", type=float, default=0.0)
    parser.add_argument("--batches", type=int, default=125)
    parser.add_argument("--ema", action='store_true', default=False)
    parser.add_argument("--epoch", type=int, default=-1)
    parser.add_argument("--w", type=float, default=0.3)
    opt = parser.parse_args()

    init_seeds(no=opt.local_rank)
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(opt.local_rank)
    sample(opt)
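
# Note: --local_rank is the calling convention of torch.distributed.launch.
# Under the newer torchrun launcher (an assumption about the launch setup,
# not part of this repo), the local rank arrives in the LOCAL_RANK
# environment variable instead, so a shim along these lines would be needed
# before init_seeds():
#
#     opt.local_rank = int(os.environ.get("LOCAL_RANK", opt.local_rank))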