# edm.py
import torch
# from torch_utils import persistence
import numpy as np
# from craftsman.utils.base import BaseModule
# @persistence.persistent_class
# EDM2 training loss with learned per-noise-level uncertainty weighting (logvar).
class EDM2Loss:
    def __init__(self, P_mean=-0.4, P_std=1.0, sigma_data=0.5):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data

    def __call__(self, net, images, labels):
        # print('images shape: {} labels shape: {}'.format(images.shape, labels.shape))
        # Sample a log-normal noise level per image.
        rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        noise = torch.randn_like(images) * sigma
        # import ipdb; ipdb.set_trace(context = 10)
        denoised, logvar = net(images + noise, sigma, labels, return_logvar=True)
        loss = (weight / logvar.exp()) * ((denoised - images) ** 2) + logvar
        return loss
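
# --- Illustrative usage sketch (not part of the original file) ---
# A minimal example of driving EDM2Loss with a dummy preconditioned net.  The
# `_DummyNet`, the tensor shapes, and the label layout are assumptions for
# demonstration only; in this file `net` would typically be an EDMPrecond
# instance (defined below).
def _demo_edm2_loss():
    class _DummyNet(torch.nn.Module):
        def forward(self, x, sigma, labels, return_logvar=False):
            # Pretend the denoiser returns the noisy input unchanged and a
            # constant per-sample log-variance.
            logvar = torch.zeros(x.shape[0], 1, 1, 1)
            return (x, logvar) if return_logvar else x
    images = torch.randn(4, 3, 8, 8)   # assumed batch of 4 small images
    labels = torch.zeros(4, 16)        # assumed 16-dim conditioning vectors
    loss = EDM2Loss()(_DummyNet(), images, labels)
    print('per-element loss shape:', loss.shape)  # same shape as `images`
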
def normalize(x, dim=None, eps=1e-4):
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
    return x / norm.to(x.dtype)
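
# --- Illustrative check (not part of the original file) ---
# `normalize` rescales each slice along the reduced dims (dims 1..ndim-1 by
# default) so that its per-element RMS is approximately 1.  The tensor sizes
# below are arbitrary assumptions.
def _demo_normalize():
    w = torch.randn(8, 16, 3, 3) * 5.0
    w_hat = normalize(w)
    rms = w_hat.square().mean(dim=(1, 2, 3)).sqrt()
    print('per-sample RMS after normalize:', rms)  # close to 1 for each sample
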
#----------------------------------------------------------------------------
# EDM sampler from the paper
# "Elucidating the Design Space of Diffusion-Based Generative Models",
# extended to support classifier-free guidance.
# The same network is queried for both the conditional and unconditional branches.
def edm_sampler(
    net, noise, labels,
    num_steps=32, sigma_min=0.002, sigma_max=80, rho=7, guidance=1,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
    # S_churn=40, S_min=0, S_max=float('inf'), S_noise=1,
    dtype=torch.float32, randn_like=torch.randn_like,
):
    # Guided denoiser.
    if guidance != 1:
        uncond, cond = labels.chunk(2)
    def denoise(x, t):
        if guidance == 1:
            Dx = net(x, t, labels).to(dtype)
            return Dx
        else:
            Dx = net(x, t, cond).to(dtype)
            ref_Dx = net(x, t, uncond).to(dtype)
            return ref_Dx.lerp(Dx, guidance)

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=dtype, device=noise.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    x_next = noise.to(dtype) * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        if S_churn > 0 and S_min <= t_cur <= S_max:
            gamma = min(S_churn / num_steps, np.sqrt(2) - 1)
            t_hat = t_cur + gamma * t_cur
            x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
        else:
            t_hat = t_cur
            x_hat = x_cur

        # Euler step.
        d_cur = (x_hat - denoise(x_hat, t_hat)) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction.
        if i < num_steps - 1:
            d_prime = (x_next - denoise(x_next, t_next)) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
    return x_next
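
# --- Illustrative usage sketch (not part of the original file) ---
# Sampling from a toy denoiser.  The dummy net, shapes, and label sizes are
# assumptions; in practice `net` is an EDMPrecond instance and `labels` carries
# the real conditioning.
def _demo_edm_sampler():
    class _DummyDenoiser(torch.nn.Module):
        def forward(self, x, sigma, labels):
            # A trivial "denoiser" that always predicts zeros.
            return torch.zeros_like(x)
    noise = torch.randn(2, 3, 8, 8)
    labels = torch.zeros(2, 16)
    samples = edm_sampler(_DummyDenoiser(), noise, labels, num_steps=8)
    print('sample shape:', samples.shape)
    # With guidance != 1, `labels` must hold the unconditional and conditional
    # halves stacked along the batch dimension, each half matching the batch
    # size of `noise`, e.g.:
    #   edm_sampler(net, noise, torch.cat([uncond, cond]), guidance=2)
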
#----------------------------------------------------------------------------
# Magnitude-preserving Fourier features (Equation 75).
# @persistence.persistent_class
class MPFourier(torch.nn.Module):
    def __init__(self, num_channels, bandwidth=1):
        super().__init__()
        self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth)
        self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels))
        # self.freqs = self.freqs.to('cuda')
        # self.phases = self.phases.to('cuda')

    def forward(self, x):
        y = x.to(torch.float32)
        # import pdb; pdb.set_trace()
        y = y.ger(self.freqs.to(torch.float32))
        y = y + self.phases.to(torch.float32)
        y = y.cos() * np.sqrt(2)
        return y.to(x.dtype)
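
# --- Illustrative check (not part of the original file) ---
# MPFourier maps a 1-D batch of scalars (e.g. the noise-level embedding
# c_noise) to `num_channels` cosine features whose overall magnitude is
# approximately preserved (unit standard deviation).  The input values below
# are arbitrary assumptions.
def _demo_mp_fourier():
    emb = MPFourier(num_channels=64)
    c_noise = torch.randn(16)                  # e.g. log(sigma) / 4 values
    feats = emb(c_noise)                       # shape (16, 64)
    print('feature std:', feats.std().item())  # roughly 1
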
#----------------------------------------------------------------------------
# Magnitude-preserving convolution or fully-connected layer (Equation 47)
# with forced weight normalization (Equation 66).
# @persistence.persistent_class
class MPConv(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))

    def forward(self, x, gain=1):
        w = self.weight.to(torch.float32)
        if self.training:
            with torch.no_grad():
                self.weight.copy_(normalize(w)) # forced weight normalization
        w = normalize(w) # traditional weight normalization
        w = w * (gain / np.sqrt(w[0].numel())) # magnitude-preserving scaling
        w = w.to(x.dtype)
        if w.ndim == 2:
            # print('fc version')
            return x @ w.t()
        assert w.ndim == 4
        return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1]//2,))
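
# --- Illustrative check (not part of the original file) ---
# MPConv acts as a fully-connected layer when `kernel=[]` and as a 2-D
# convolution when a kernel shape is given; in both cases the weight is
# re-normalized on every forward pass.  The sizes below are assumptions.
def _demo_mp_conv():
    fc = MPConv(in_channels=32, out_channels=8, kernel=[])
    x = torch.randn(4, 32)
    print('fc output shape:', fc(x).shape)        # (4, 8)
    conv = MPConv(in_channels=3, out_channels=16, kernel=[3, 3])
    img = torch.randn(4, 3, 8, 8)
    print('conv output shape:', conv(img).shape)  # (4, 16, 8, 8)
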
# @persistence.persistent_class
class EDMPrecond(torch.nn.Module):
# class EDMPrecond(BaseModule):
    def __init__(self,
        denoiser_model,                  # Inner denoiser backbone, called as F(x_in, c_noise, class_labels).
        # label_dim,                     # Class label dimensionality. 0 = unconditional.
        sigma_data      = 0.5,           # Expected standard deviation of the training data.
        logvar_channels = 128,           # Intermediate dimensionality for uncertainty estimation.
    ):
        super().__init__()
        # self.label_dim = label_dim
        # self.use_fp16 = use_fp16
        self.sigma_data = sigma_data
        self.denoiser_model = denoiser_model
        self.logvar_fourier = MPFourier(logvar_channels)
        self.logvar_linear = MPConv(logvar_channels, 1, kernel=[])
    def forward(self, x, sigma, class_labels, return_logvar=False):
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        # class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
        # dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
        # class_labels = class_labels.to(torch.float32)
        dtype = x.dtype

        # Preconditioning weights.
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
        c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
        c_noise = sigma.flatten().log() / 4

        # Run the model.
        x_in = (c_in * x).to(dtype)
        # print('xin {} cnoise {} class_label {}'.format(x_in.shape, c_noise.shape, class_labels.shape))
        F_x = self.denoiser_model(x_in, c_noise, class_labels)
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        # import ipdb; ipdb.set_trace(context = 10)

        # Estimate uncertainty if requested.
        if return_logvar:
            # Reshape to (-1, 1, 1, 1) so the log-variance broadcasts against the
            # (B, C, H, W) tensors used in EDM2Loss.
            logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
            return D_x, logvar # u(sigma) in Equation 21
        return D_x
#----------------------------------------------------------------------------
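
# --- Illustrative end-to-end sketch (not part of the original file) ---
# Wiring the pieces together: wrap a toy backbone in EDMPrecond, take one
# training step with EDM2Loss, then sample with edm_sampler.  The backbone,
# tensor shapes, and label dimensionality are assumptions for demonstration.
def _demo_end_to_end():
    class _ToyBackbone(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = MPConv(3, 3, kernel=[3, 3])
        def forward(self, x_in, c_noise, class_labels):
            # Ignores the conditioning; a real backbone would also consume
            # `c_noise` and `class_labels`.
            return self.conv(x_in)

    net = EDMPrecond(_ToyBackbone())
    images = torch.randn(4, 3, 8, 8)
    labels = torch.zeros(4, 16)

    # One training step: per-element loss, reduced to a scalar, backpropagated.
    loss = EDM2Loss()(net, images, labels).mean()
    loss.backward()
    print('loss:', loss.item())

    # Sampling with the (untrained) preconditioned net.
    net.eval()
    with torch.no_grad():
        samples = edm_sampler(net, torch.randn(2, 3, 8, 8), labels[:2], num_steps=4)
    print('sample shape:', samples.shape)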