model.py
import math

import torch
import torch.nn as nn

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
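
# A minimal shape-check sketch (illustrative; the names below are assumptions,
# not part of the original file):
#   emb = SinusoidalPositionEmbeddings(128)
#   t = torch.randint(0, 300, (4,))   # (B,) integer timesteps
#   emb(t).shape                      # -> torch.Size([4, 128])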

class Block(nn.Module):
    def __init__(self, channels_in, channels_out, time_embedding_dims, labels, num_filters=3, downsample=True):
        super().__init__()
        self.time_embedding_dims = time_embedding_dims
        self.time_embedding = SinusoidalPositionEmbeddings(time_embedding_dims)
        self.labels = labels
        if labels:
            self.label_mlp = nn.Linear(1, channels_out)
        self.downsample = downsample
        if downsample:
            self.conv1 = nn.Conv2d(channels_in, channels_out, num_filters, padding=1)
            self.final = nn.Conv2d(channels_out, channels_out, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(2 * channels_in, channels_out, num_filters, padding=1)
            self.final = nn.ConvTranspose2d(channels_out, channels_out, 4, 2, 1)
        self.bnorm1 = nn.BatchNorm2d(channels_out)
        self.bnorm2 = nn.BatchNorm2d(channels_out)
        self.conv2 = nn.Conv2d(channels_out, channels_out, 3, padding=1)
        self.time_mlp = nn.Linear(time_embedding_dims, channels_out)
        self.relu = nn.ReLU()

    def forward(self, x, t, **kwargs):
        o = self.bnorm1(self.relu(self.conv1(x)))
        o_time = self.relu(self.time_mlp(self.time_embedding(t)))
        o = o + o_time[(..., ) + (None, ) * 2]
        if self.labels:
            label = kwargs.get('labels')
            o_label = self.relu(self.label_mlp(label))
            o = o + o_label[(..., ) + (None, ) * 2]
        o = self.bnorm2(self.relu(self.conv2(o)))
        return self.final(o)
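
# Shape sketch (illustrative; the names below are assumptions, not part of the
# original file). A downsampling Block halves H and W via the strided conv in
# `final`; an upsampling Block expects its input concatenated with a skip
# connection along channels (hence `2 * channels_in`) and doubles H and W via
# the transposed conv:
#   down = Block(64, 128, time_embedding_dims=128, labels=False)
#   t = torch.randint(0, 300, (2,))
#   down(torch.randn(2, 64, 32, 32), t).shape   # -> torch.Size([2, 128, 16, 16])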

class UNet(nn.Module):
    def __init__(self, img_channels=3, time_embedding_dims=128, labels=False, sequence_channels=(64, 128, 256, 512, 1024)):
        super().__init__()
        self.time_embedding_dims = time_embedding_dims
        self.downsampling = nn.ModuleList([
            Block(channels_in, channels_out, time_embedding_dims, labels)
            for channels_in, channels_out in zip(sequence_channels, sequence_channels[1:])
        ])
        self.upsampling = nn.ModuleList([
            Block(channels_in, channels_out, time_embedding_dims, labels, downsample=False)
            for channels_in, channels_out in zip(sequence_channels[::-1], sequence_channels[::-1][1:])
        ])
        self.conv1 = nn.Conv2d(img_channels, sequence_channels[0], 3, padding=1)
        self.conv2 = nn.Conv2d(sequence_channels[0], img_channels, 1)

    def forward(self, x, t, **kwargs):
        residuals = []
        o = self.conv1(x)
        for ds in self.downsampling:
            o = ds(o, t, **kwargs)
            residuals.append(o)
        for us, res in zip(self.upsampling, reversed(residuals)):
            o = us(torch.cat((o, res), dim=1), t, **kwargs)
        return self.conv2(o)
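
# Illustrative end-to-end shape check (an assumption, not part of the original
# file); input H and W must be divisible by 2**4 = 16 because the four
# downsampling Blocks each halve the resolution:
#   net = UNet()
#   x = torch.randn(2, 3, 32, 32)
#   t = torch.randint(0, 300, (2,))
#   net(x, t).shape   # -> torch.Size([2, 3, 32, 32])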

class DiffusionModel:
    def __init__(self, start_schedule=0.0001, end_schedule=0.02, timesteps=300):
        self.start_schedule = start_schedule
        self.end_schedule = end_schedule
        self.timesteps = timesteps
        # if
        #     betas = [0.1, 0.2, 0.3, ...]
        # then
        #     alphas = [0.9, 0.8, 0.7, ...]
        #     alphas_cumprod = [0.9, 0.9 * 0.8, 0.9 * 0.8 * 0.7, ...]
        self.betas = torch.linspace(start_schedule, end_schedule, timesteps)
        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    def forward(self, x_0, t, device):
        """
        x_0: (B, C, H, W)
        t: (B,)
        """
        noise = torch.randn_like(x_0)
        sqrt_alphas_cumprod_t = self.get_index_from_list(self.alphas_cumprod.sqrt(), t, x_0.shape)
        sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(torch.sqrt(1. - self.alphas_cumprod), t, x_0.shape)
        mean = sqrt_alphas_cumprod_t.to(device) * x_0.to(device)
        variance = sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device)
        return mean + variance, noise.to(device)
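
    # Closed-form forward (noising) process computed above:
    #   x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise
    # (`variance` is really the scaled-noise term std * noise, not a variance)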

    @torch.no_grad()
    def backward(self, x, t, model, **kwargs):
        """
        Calls the model to predict the noise in the image and returns
        the denoised image. Noise is added back in unless t == 0, i.e.
        unless this is the final denoising step. (The `t == 0` check
        compares a tensor, so `t` should contain a single element.)
        """
        betas_t = self.get_index_from_list(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(torch.sqrt(1. - self.alphas_cumprod), t, x.shape)
        sqrt_recip_alphas_t = self.get_index_from_list(torch.sqrt(1.0 / self.alphas), t, x.shape)
        mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t, **kwargs) / sqrt_one_minus_alphas_cumprod_t)
        posterior_variance_t = betas_t
        if t == 0:
            return mean
        else:
            noise = torch.randn_like(x)
            variance = torch.sqrt(posterior_variance_t) * noise
            return mean + variance
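
    # Reverse step computed above (DDPM sampling, with the model as noise predictor):
    #   mean    = 1 / sqrt(alphas[t]) * (x - betas[t] / sqrt(1 - alphas_cumprod[t]) * model(x, t))
    #   x_{t-1} = mean + sqrt(betas[t]) * noise   (noise omitted when t == 0)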

    @staticmethod
    def get_index_from_list(values, t, x_shape):
        batch_size = t.shape[0]
        # pick the entries of `values` according to the indices stored in `t`
        result = values.gather(-1, t.cpu())
        # if
        #     x_shape = (5, 3, 64, 64)
        #     -> len(x_shape) = 4
        #     -> len(x_shape) - 1 = 3
        # and thus we reshape `result` to dims (batch_size, 1, 1, 1)
        return result.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
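
# Minimal usage sketch (an assumption, not part of the original file): one
# noising step, the usual epsilon-prediction loss, and one reverse step on
# random 3x32x32 inputs.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = UNet().to(device)
    diffusion = DiffusionModel()

    # forward (noising) pass and noise-prediction training loss
    x_0 = torch.randn(2, 3, 32, 32, device=device)
    t = torch.randint(0, diffusion.timesteps, (2,), device=device)
    x_t, noise = diffusion.forward(x_0, t, device)
    loss = nn.functional.mse_loss(model(x_t, t), noise)
    print(f"training loss on random data: {loss.item():.4f}")

    # single reverse step at the last timestep (batch of one, since
    # backward() compares `t` against 0 as a single-element tensor)
    x = torch.randn(1, 3, 32, 32, device=device)
    t_last = torch.full((1,), diffusion.timesteps - 1, device=device, dtype=torch.long)
    x = diffusion.backward(x, t_last, model.eval())
    print(f"denoised sample shape: {tuple(x.shape)}")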