-
Notifications
You must be signed in to change notification settings - Fork 0
/
sampling.py
137 lines (108 loc) · 4.93 KB
/
sampling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Sampling driver: loads a trained NCSNv2 score model and runs annealed
# Langevin dynamics across a sweep of SNR levels, saving the best samples.
import numpy as np
from dotmap import DotMap
from ncsnv2.models.ema import EMAHelper
import torch, sys, os, json, argparse
# Make repo-local modules (annealedLangevin, utils) importable when the
# script is launched from the repository root.
sys.path.append('.')
# Args
parser = argparse.ArgumentParser()
parser.add_argument('--config_path', type=str)
# All run options come from one JSON file, wrapped in a DotMap so nested
# keys read as attributes (args.model.gpu, args.sampling.*, ...).
args = DotMap(json.load(open(parser.parse_args().config_path)))
from tqdm import tqdm as tqdm
from ncsnv2.models.ncsnv2 import NCSNv2Deepest, NCSNv2Deeper, NCSNv2
from annealedLangevin import ald
from utils import *
# Always !!!
# Enable TF32 matmul/cudnn paths and cudnn autotuning for speed on Ampere+.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
# GPU
# Pin the visible device before any CUDA context is created.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID";
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.model.gpu);
# Target weights - replace with target model
# NOTE(review): torch.load unpickles arbitrary objects — only load
# checkpoints from trusted sources.
contents = torch.load(args.sampling.target_model)
# Extract config
# Reuse the training-time config stored in the checkpoint, then overlay the
# sampling-time choices from the command-line JSON config.
config = contents['config']
config.model.depth = args.model.depth
config.sampling = args.sampling
config.sampling.sigma = 0.
config.model.num_classes = args.model.num_classes
# Number of annealing steps = number of noise levels minus the offset of
# levels skipped at the start of the schedule.
config.sampling.num_steps = config.model.num_classes - config.sampling.sigma_offset
# Range of SNR, test channels and hyper-parameters
# Convert SNR (dB) to noise power, scaled by the last oracle dimension
# (presumably so per-element noise matches the stated SNR — TODO confirm).
config.sampling.noise_range = 10 ** (-torch.tensor(config.sampling.snr_range) / 10.) * config.sampling.oracle_shape[-1]
# Get a model
# Instantiate the score network matching the depth the checkpoint was
# trained with. Fail loudly on an unknown depth instead of leaving
# `diffuser` undefined and hitting a confusing NameError below.
if config.model.depth == 'large':
    diffuser = NCSNv2Deepest(config)
elif config.model.depth == 'medium':
    diffuser = NCSNv2Deeper(config)
elif config.model.depth == 'low':
    diffuser = NCSNv2(config)
else:
    raise ValueError('Unknown config.model.depth: %s' % config.model.depth)
# Prefer the sigma schedule recorded at training time when one exists;
# either way, mirror the model's effective schedule back into the config.
if len(config.training.sigmas) > 0:
    diffuser.sigmas = config.training.sigmas
config.training.sigmas = diffuser.sigmas
diffuser = diffuser.cuda()
# !!! Load weights — prefer the EMA (exponential moving average) weights,
# falling back to the raw model state if the checkpoint has no usable EMA.
try:
    contents['ema_state']['sigmas'] = config.training.sigmas
    diffuser.load_state_dict(contents['ema_state'])
except Exception:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
    # still abort the run instead of silently loading fallback weights.
    diffuser.load_state_dict(contents['model_state'])
    print('Failed to load EMA, defaulting to Model State')
diffuser.eval()
def _choose_step_size(sigma_end, sigma_rate, steps_each):
    """Return the core Langevin step size (epsilon) per [Song '20].

    Searches a log-spaced grid of 10k candidate epsilons for the one whose
    convergence criterion is closest to 1, given the geometric sigma
    schedule's final sigma (`sigma_end`), the per-level ratio
    (`sigma_rate`), and the number of Langevin updates per noise level
    (`steps_each`).
    """
    candidate_steps = np.logspace(-11, -7, 10000)
    gamma_rate = 1 / sigma_rate
    # sigma_end**2 and gamma_rate**2 are loop invariants; the criterion is
    # evaluated for the whole candidate grid at once (vectorized) instead
    # of a Python-level loop — identical float64 math, one pass in C.
    sigma_squared = sigma_end ** 2
    one_minus_ratio = (1 - candidate_steps / sigma_squared) ** 2
    big_ratio = 2 * candidate_steps / \
        (sigma_squared - sigma_squared * one_minus_ratio)
    step_criterion = one_minus_ratio ** steps_each * \
        (gamma_rate ** 2 - big_ratio) + big_ratio
    best_idx = np.argmin(np.abs(step_criterion - 1.))
    return candidate_steps[best_idx]

if not config.sampling.step_size or config.sampling.prior_sampling == 1:
    # Choose the core step size (epsilon) according to [Song '20]
    config.sampling.step_size = torch.tensor(
        _choose_step_size(config.model.sigma_end, config.model.sigma_rate,
                          config.sampling.steps_each))
config.data = args.data
config.training = args.training
# Global results
# Output directory encodes dataset, dataloader, and the checkpoint's parent
# folder name so runs from different models do not overwrite each other.
result_dir = './results/' + config.data.file + '_' + config.data.dataloader + '/' + config.sampling.target_model.split("/")[-2]
# exist_ok=True replaces the racy isdir()-then-makedirs() check: safe even
# if another process creates the directory between check and create.
os.makedirs(result_dir, exist_ok=True)
# Placeholders filled in below depending on prior vs. posterior sampling.
forward_model = None
Y_adj, oracle, forward_operator, adjoint_operator = None, None, None, None
print('Dataset: ' + config.data.file)
print('Dataloader: ' + config.data.dataloader)
print('\nStep Size: ' + str(np.float64(config.sampling.step_size)) + '\n')
best_images = []
if config.sampling.prior_sampling == 0:
    # Posterior sampling: instantiate the forward (measurement) model whose
    # class name is given in the config. globals() resolves it from the
    # wildcard `from utils import *` at the top of the script.
    print('Forward Class: ' + config.sampling.forward_class)
    forward_model = globals()[config.sampling.forward_class]()
    Y_adj, oracle, forward_operator, adjoint_operator = forward_model.DataLoader(config)
    config.sampling.channels = oracle.shape[0]
    # Initialize the chain with Gaussian noise shaped like the oracle.
    init_val_X = torch.randn_like(oracle).cuda()
else:
    # Prior sampling: no measurements. Start from complex white noise and
    # reuse that noise as the "oracle" placeholder; collapse the sweep to a
    # single dummy noise level.
    noise_shape = (config.sampling.channels,
                   config.sampling.oracle_shape[0],
                   config.sampling.oracle_shape[1])
    real_part = torch.randn(*noise_shape, dtype=torch.float)
    imag_part = torch.randn(*noise_shape, dtype=torch.float)
    init_val_X = torch.complex(real_part, imag_part).cuda()
    config.sampling.noise_range = [1]
    config.sampling.noise_boost = 1
    oracle = init_val_X.clone()
    Y_adj = torch.zeros(len(config.sampling.noise_range))
    config.sampling.sampling_file = 'prior'
# Run annealed Langevin dynamics once per noise level in the SNR sweep.
for snr_idx, noise_level in tqdm(enumerate(config.sampling.noise_range)):
    # Publish the current noise level on the shared config object so that
    # downstream code (ald) can read it.
    config.sampling.local_noise = noise_level
    print('\n\nSampling for SNR Level ' + str(snr_idx) + ': ' + str(config.sampling.snr_range[snr_idx]))
    # Each level restarts from the same random initialization.
    chain_state = init_val_X.clone()
    best_images.append(ald(diffuser, config, Y_adj[snr_idx], oracle,
                           chain_state, forward_operator, adjoint_operator))
    torch.cuda.empty_cache()
# Persist everything needed to evaluate the run later.
save_dict = {'snr_range': config.sampling.snr_range,
             'config': config,
             'oracle_H': oracle,
             'best_images': best_images}
torch.save(save_dict, result_dir + '/' + config.sampling.sampling_file + '.pt')