-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
323 lines (255 loc) · 10.9 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
# Reference: https://github.com/ziwei-jiang/PGGAN-PyTorch/blob/master/model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
# "We use leaky ReLU with leakiness 0.2 in all layers of both networks, except for the last layer
# that uses linear activation."
LEAKINESS = 0.2
# Numerator of the runtime weight scale c = sqrt(GAIN / fan_in) used by the equalized-LR layers.
# He's initializer uses 2 for ReLU-family networks so activation variance stays roughly constant
# from layer to layer; a smaller value (e.g. 0.2) shrinks every layer's output variance ~10x.
GAIN = 2
class EqualLRLinear(nn.Module):
    r"""Fully connected layer with an equalized learning rate.

    Quoted notes:
    "EQUALIZED LEARNING RATE: We use a trivial $\mathcal{N}(0, 1)$ initialization and then
    explicitly scale the weights at runtime. We set $\hat{w}_{i} = w_{i} / c$, where $w_{i}$ are the weights
    and $c$ is the per-layer normalization constant from He's initializer."
    "We initialize all bias parameters to zero and all weights according to the normal distribution
    with unit variance. However, we scale the weights with a layer-specific constant at runtime."
    "The idea is to scale the parameters of each layer just before every forward propagation
    that passes through. How much to scale by is determined by a statistic calculated
    from the parameter values of each layer."

    In this implementation the stored ``weight`` is drawn from N(0, 1) and multiplied by
    ``scale = sqrt(gain / in_features)`` on every forward pass; ``bias`` starts at zero and
    is used unscaled.

    Args:
        in_features: size of each input sample; this is the fan-in used in ``scale``.
        out_features: size of each output sample.
        gain: numerator of the runtime scale; defaults to the module-level ``GAIN``.
    """

    def __init__(self, in_features, out_features, gain=GAIN):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.gain = gain
        self.scale = np.sqrt(gain / in_features)
        # Uninitialized storage; overwritten by the init calls right below.
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.bias = nn.Parameter(torch.Tensor(out_features))
        nn.init.normal_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Apply ``x @ (weight * scale).T + bias``; the scaling happens at runtime, every call."""
        return F.linear(x, weight=self.weight * self.scale, bias=self.bias)
class EqualLRConv2d(nn.Module):
    """Square-kernel 2-D convolution with an equalized learning rate.

    Weights are sampled from N(0, 1) and multiplied at every forward pass by
    ``scale = sqrt(gain / fan_in)`` with ``fan_in = in_channels * kernel_size * kernel_size``.
    Biases start at zero and are used unscaled. ``kernel_size`` is a single int.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, gain=GAIN):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size = kernel_size
        self.stride, self.padding = stride, padding
        self.gain = gain
        fan_in = in_channels * kernel_size * kernel_size
        self.scale = (gain / fan_in) ** 0.5
        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.Tensor(*weight_shape))
        self.bias = nn.Parameter(torch.Tensor(out_channels))
        nn.init.normal_(self.weight, mean=0.0, std=1.0)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Convolve ``x`` with the runtime-scaled kernel."""
        scaled_kernel = self.weight * self.scale
        return F.conv2d(x, scaled_kernel, self.bias, stride=self.stride, padding=self.padding)
class DownsampleBlock(nn.Module):
    """Discriminator block built from two equalized-LR convolutions with leaky ReLU.

    downsample=True (every block except the last):
        conv 3x3 (in -> out) -> LReLU -> conv 3x3 (out -> out) -> LReLU -> 2x2 average pooling.
    downsample=False (the final block):
        append a minibatch-stddev feature map (in + 1 channels) -> conv 3x3 -> LReLU
        -> unpadded conv 4x4 -> LReLU -> linear projection to one score per sample,
        returned with shape (-1, 1, 1, 1).
    """

    def __init__(self, in_channels, out_channels, downsample=True, leakiness=LEAKINESS):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample
        self.leakiness = leakiness
        if downsample:
            self.conv1 = EqualLRConv2d(in_channels, out_channels, kernel_size=3, padding=1)
            self.conv2 = EqualLRConv2d(out_channels, out_channels, kernel_size=3, padding=1)
        else:
            # One extra input channel for the minibatch-stddev map appended in forward().
            self.conv1 = EqualLRConv2d(in_channels + 1, out_channels, kernel_size=3, padding=1)
            # Unpadded 4x4 kernel: a 4x4 input collapses to 1x1.
            self.conv2 = EqualLRConv2d(out_channels, out_channels, kernel_size=4)
            # Final layer: linear, no activation afterwards.
            self.proj = EqualLRLinear(out_channels, 1)

    def add_minibatch_std(self, x, eps=1e-8):
        """Append one constant feature map holding the average across-batch standard deviation.

        "We compute the standard deviation for each feature in each spatial location over the minibatch.
        We then average these estimates over all features and spatial locations to arrive at a single value.
        We replicate the value and concatenate it to all spatial locations and over the minibatch,
        yielding one additional (constant) feature map."

        The population (biased) variance is used and ``eps`` is added before the square root,
        so a batch of size 1 or a batch of identical samples produces sqrt(eps) instead of
        NaN values / NaN gradients (which the unbiased ``Tensor.std`` produces in those cases).

        Args:
            x: tensor unpacked as (batch, channels, height, width).
            eps: small constant added to the variance before the square root.

        Returns:
            Tensor of shape (batch, channels + 1, height, width).
        """
        b, _, h, w = x.shape
        centered = x - x.mean(dim=0, keepdim=True)
        std = torch.sqrt((centered ** 2).mean(dim=0, keepdim=True) + eps)
        # Single scalar per batch, kept as a (1, 1, 1, 1) tensor for broadcasting via repeat.
        feat_map = std.mean(dim=(1, 2, 3), keepdim=True)
        return torch.cat([x, feat_map.repeat(b, 1, h, w)], dim=1)

    def forward(self, x):
        """Run the block; see the class docstring for the layer order of each variant."""
        if not self.downsample:
            # "We inject the across-minibatch standard deviation as an additional feature map
            # at 4×4 resolution toward the end of the discriminator."
            x = self.add_minibatch_std(x)
        x = F.leaky_relu(self.conv1(x), negative_slope=self.leakiness)
        x = F.leaky_relu(self.conv2(x), negative_slope=self.leakiness)
        if self.downsample:
            return _half(x)
        # Assumes a 4x4 input, so each sample is (out_channels, 1, 1) here and the view
        # yields exactly one row per sample.
        x = self.proj(x.view(-1, self.out_channels))
        return x.view(-1, 1, 1, 1)
class FromRGB(nn.Module):
    """Map a 3-channel RGB image to ``out_channels`` feature maps.

    "The `fromRGB` does the reverse of `toRGB`. It uses 1×1 convolutions."
    Implemented as an equalized-LR 1x1 convolution followed by leaky ReLU.
    """

    def __init__(self, out_channels, leakiness=LEAKINESS):
        super().__init__()
        self.out_channels = out_channels
        self.leakiness = leakiness
        self.conv = EqualLRConv2d(3, out_channels, kernel_size=1)

    def forward(self, x):
        """Project RGB to features and apply the leaky ReLU."""
        features = self.conv(x)
        return F.leaky_relu(features, negative_slope=self.leakiness)
def _half(x):
"""
"'0.5×' refer to halving the image resolution using nearest neighbor average pooling."
"""
return F.avg_pool2d(x, kernel_size=2, stride=2)
class Discriminator(nn.Module):  # 25,444,737 parameters in total.
    """Progressive-growing discriminator for square resolutions from 4x4 up to 1024x1024.

    ``from_rgb{d}`` / ``block{d}`` handle resolution 2 ** (d + 1):
    d = 1 -> 4x4 (final block, minibatch stddev + score), ..., d = 9 -> 1024x1024.
    """

    def __init__(self):
        super().__init__()
        # NOTE: registration order is kept as-is; it fixes the order of parameters(),
        # which optimizer state_dicts and seeded initialization depend on.
        self.block1 = DownsampleBlock(512, 512, downsample=False)
        self.block2 = DownsampleBlock(512, 512)
        self.block7 = DownsampleBlock(64, 128)
        self.block4 = DownsampleBlock(512, 512)
        self.block5 = DownsampleBlock(256, 512)
        self.block6 = DownsampleBlock(128, 256)
        self.block8 = DownsampleBlock(32, 64)
        self.block3 = DownsampleBlock(512, 512)
        self.block9 = DownsampleBlock(16, 32)
        self.from_rgb1 = FromRGB(512)
        self.from_rgb2 = FromRGB(512)
        self.from_rgb3 = FromRGB(512)
        self.from_rgb4 = FromRGB(512)
        self.from_rgb5 = FromRGB(256)
        self.from_rgb6 = FromRGB(128)
        self.from_rgb7 = FromRGB(64)
        self.from_rgb8 = FromRGB(32)
        self.from_rgb9 = FromRGB(16)

    def forward(self, x, img_size, alpha):
        """Score a batch of RGB images.

        Args:
            x: RGB images at ``img_size`` x ``img_size`` (the smoke test feeds (batch, 3, H, W)).
            img_size: current resolution, a power of two from 4 to 1024; other values make
                the block lookup fail with AttributeError.
            alpha: fade-in weight of the newest (highest-resolution) block; ignored at 4x4.

        Returns:
            Scores of shape (batch, 1, 1, 1), as produced by ``block1``.
        """
        if img_size == 4:
            return self.block1(self.from_rgb1(x))
        depth = _get_depth(img_size)
        # Fade-in: blend the new highest-resolution path with the input downsampled and fed
        # through the previous stage's fromRGB.
        skip = getattr(self, f"from_rgb{depth - 1}")(_half(x))
        x = getattr(self, f"from_rgb{depth}")(x)
        x = getattr(self, f"block{depth}")(x)
        x = (1 - alpha) * skip + alpha * x
        for d in range(depth - 1, 0, -1):
            x = getattr(self, f"block{d}")(x)
        return x
class UpsampleBlock(nn.Module):
    """Generator block: two equalized-LR convolutions, each followed by leaky ReLU and pixel norm.

    upsample=True: 2x nearest-neighbour upsampling, then conv 3x3 (in -> out) and conv 3x3 (out -> out).
    upsample=False (first block): conv 4x4 with padding 3 (turns a 1x1 input into 4x4), then conv 3x3.
    """

    def __init__(self, in_channels, out_channels, upsample=True, leakiness=LEAKINESS):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.upsample = upsample
        self.leakiness = leakiness
        first_kernel, first_padding = (3, 1) if upsample else (4, 3)
        self.conv1 = EqualLRConv2d(in_channels, out_channels, kernel_size=first_kernel, padding=first_padding)
        self.conv2 = EqualLRConv2d(out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Optionally upsample, then apply conv -> leaky ReLU -> pixel norm twice."""
        if self.upsample:
            x = _double(x)
        for conv in (self.conv1, self.conv2):
            x = F.leaky_relu(conv(x), negative_slope=self.leakiness)
            # "We perform pixel-wise normalization of the feature vectors after each Conv 3×3 layer
            # in the generator." (applied after every conv here, including the first block's 4x4)
            x = perform_pixel_norm(x)
        return x
class ToRGB(nn.Module):
    """Project feature maps to a 3-channel RGB image with an equalized-LR 1x1 convolution.

    "The `toRGB` represents a layer that projects feature vectors to RGB colors. It uses 1×1 convolutions."
    No activation is applied; ``leakiness`` is stored but not used by ``forward``.
    """

    def __init__(self, in_channels, leakiness=LEAKINESS):
        super().__init__()
        self.in_channels = in_channels
        self.leakiness = leakiness
        self.conv = EqualLRConv2d(in_channels, 3, kernel_size=1)

    def forward(self, x):
        """Return the linear 1x1 projection of ``x`` to RGB."""
        return self.conv(x)
def _get_depth(img_size):
depth = int(math.log2(img_size)) - 1
return depth
def perform_pixel_norm(x, eps=1e-8):
    r"""Normalize every pixel's feature vector to unit root-mean-square across channels.

    "PIXELWISE FEATURE VECTOR NORMALIZATION IN GENERATOR: We normalize the feature vector in each pixel
    to unit length in the generator after each convolutional layer."
    "$b_{x, y} = a_{x, y} / \sqrt{\frac{1}{N} \sum^{N - 1}_{j=0} (a^{j}_{x, y})^{2} + \epsilon}$, where
    $\epsilon = 10^{-8}$, $N$ is the number of feature maps, and $a_{x, y}$ and $b_{x, y}$ are
    the original and normalized feature vector in pixel $(x, y)$, respectively."

    Args:
        x: tensor whose dim 1 indexes the N feature maps (e.g. (batch, channels, H, W)).
        eps: the paper's epsilon, added under the square root so all-zero vectors stay finite.

    Returns:
        Tensor with the same shape as ``x``.
    """
    mean_square = (x ** 2).mean(dim=1, keepdim=True)
    return x / torch.sqrt(mean_square + eps)
def _double(x):
"""
"'2×' refer to doubling the image resolution using nearest neighbor filtering."
"""
return F.interpolate(x, scale_factor=2, mode="nearest")
class Generator(nn.Module):
    """Progressive-growing generator for square resolutions from 4x4 up to 1024x1024.

    23,079,115 ("23.1M") parameters in total.

    ``block{d}`` / ``to_rgb{d}`` produce resolution 2 ** (d + 1):
    d = 1 -> 4x4, ..., d = 9 -> 1024x1024.
    """

    def __init__(self):
        super().__init__()
        self.block1 = UpsampleBlock(512, 512, upsample=False)
        self.block2 = UpsampleBlock(512, 512)
        self.block3 = UpsampleBlock(512, 512)
        self.block4 = UpsampleBlock(512, 512)
        self.block5 = UpsampleBlock(512, 256)
        self.block6 = UpsampleBlock(256, 128)
        self.block7 = UpsampleBlock(128, 64)
        self.block8 = UpsampleBlock(64, 32)
        self.block9 = UpsampleBlock(32, 16)
        # "The last Conv 1×1 layer of the generator corresponds to the 'toRGB' block."
        self.to_rgb1 = ToRGB(512)
        self.to_rgb2 = ToRGB(512)
        self.to_rgb3 = ToRGB(512)
        self.to_rgb4 = ToRGB(512)
        self.to_rgb5 = ToRGB(256)
        self.to_rgb6 = ToRGB(128)
        self.to_rgb7 = ToRGB(64)
        self.to_rgb8 = ToRGB(32)
        self.to_rgb9 = ToRGB(16)

    def forward(self, x, img_size, alpha):
        """Generate RGB images at ``img_size`` x ``img_size``.

        Args:
            x: latent input; ``block1`` expects 512 channels (the smoke test feeds
                (batch, 512, 1, 1)).
            img_size: output resolution, a power of two from 4 to 1024; other values make
                the block lookup fail with AttributeError.
            alpha: fade-in weight of the newest (highest-resolution) block; ignored at 4x4.

        Returns:
            3-channel images at ``img_size`` x ``img_size``.
        """
        if img_size == 4:
            return self.to_rgb1(self.block1(x))
        depth = _get_depth(img_size)
        for d in range(1, depth):
            x = getattr(self, f"block{d}")(x)
        # Fade-in: blend the new block's output with the previous stage's RGB output,
        # upsampled to the new resolution.
        skip = getattr(self, f"to_rgb{depth - 1}")(_double(x))
        x = getattr(self, f"block{depth}")(x)
        x = getattr(self, f"to_rgb{depth}")(x)
        return (1 - alpha) * skip + alpha * x
if __name__ == "__main__":
    # Smoke test: build both networks, print their parameter counts and output shapes.
    from utils import print_number_of_parameters

    BATCH_SIZE = 2
    # for img_size in [4, 8, 16, 32, 64, 128, 256, 512, 1024]:
    img_size = 1024
    alpha = 0.5

    # Inference only: no_grad avoids keeping 1024x1024 activations alive for autograd.
    with torch.no_grad():
        gen = Generator()
        print_number_of_parameters(gen)
        x = torch.randn(BATCH_SIZE, 512, 1, 1)
        out = gen(x, img_size=img_size, alpha=alpha)
        print(out.shape)

        disc = Discriminator()
        print_number_of_parameters(disc)
        x = torch.randn((BATCH_SIZE, 3, img_size, img_size))
        out = disc(x, img_size=img_size, alpha=alpha)
        print(out.shape)