-
Notifications
You must be signed in to change notification settings - Fork 191
/
gan.py
339 lines (279 loc) · 10.7 KB
/
gan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
'''
An example of distribution approximation using Generative Adversarial Networks
in TensorFlow.
Based on the blog post by Eric Jang:
http://blog.evjang.com/2016/06/generative-adversarial-nets-in.html,
and of course the original GAN paper by Ian Goodfellow et al.:
https://arxiv.org/abs/1406.2661.
The minibatch discrimination technique is taken from Tim Salimans et al.:
https://arxiv.org/abs/1606.03498.
'''
import argparse
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from matplotlib import animation
import seaborn as sns
# Seed both NumPy and TensorFlow so that runs are reproducible.
seed = 42
tf.set_random_seed(seed)
np.random.seed(seed)

# Seaborn styling for every matplotlib figure produced below.
sns.set(color_codes=True)
class DataDistribution(object):
    '''
    The "real" data distribution the GAN learns to imitate: a 1-D Gaussian.

    Args:
        mu: mean of the Gaussian (default 4, the previously hard-coded value).
        sigma: standard deviation of the Gaussian (default 0.5, the previously
            hard-coded value).
    '''
    def __init__(self, mu=4, sigma=0.5):
        self.mu = mu
        self.sigma = sigma

    def sample(self, N):
        '''
        Draw N samples from N(mu, sigma^2).

        Returns:
            1-D numpy array of length N, sorted in ascending order.
        '''
        samples = np.random.normal(self.mu, self.sigma, N)
        samples.sort()
        return samples
class GeneratorDistribution(object):
    '''
    Noise source fed to the generator.

    Each sample is N evenly spaced points spanning [-range, range], each
    shifted upward by an independent uniform jitter in [0, 0.01).

    The constructor argument is named `range` (shadowing the builtin) to keep
    the keyword callers use, e.g. GeneratorDistribution(range=8).
    '''
    def __init__(self, range):
        self.range = range

    def sample(self, N):
        grid = np.linspace(-self.range, self.range, N)
        jitter = 0.01 * np.random.random(N)
        return grid + jitter
def linear(input, output_dim, scope=None, stddev=1.0):
    '''
    Fully connected layer: returns matmul(input, w) + b.

    Under the variable scope `scope` (default 'linear'), gets a weight matrix
    'w' of shape [input_dim, output_dim] and a bias 'b' of shape
    [output_dim] via tf.get_variable. When the variables are created, 'w' is
    initialised from a normal distribution with standard deviation `stddev`
    and 'b' is initialised to zeros. Inside a reusing scope the existing
    variables are fetched instead.

    `input` is assumed to be 2-D (batch, input_dim): its second dimension
    sizes `w`.
    '''
    in_dim = input.get_shape()[1]
    with tf.variable_scope(scope or 'linear'):
        weights = tf.get_variable(
            'w', [in_dim, output_dim],
            initializer=tf.random_normal_initializer(stddev=stddev))
        bias = tf.get_variable(
            'b', [output_dim],
            initializer=tf.constant_initializer(0.0))
        return tf.matmul(input, weights) + bias
def generator(input, h_dim):
    '''
    Generator MLP.

    One softplus hidden layer of width h_dim (scope 'g0') is followed by a
    linear projection to a single output unit (scope 'g1').
    '''
    hidden = tf.nn.softplus(linear(input, h_dim, 'g0'))
    return linear(hidden, 1, 'g1')
def discriminator(input, h_dim, minibatch_layer=True):
    '''
    Discriminator MLP returning a sigmoid score per row of `input`.

    Layers, in order:
      - two ReLU layers of width 2*h_dim (scopes 'd0' and 'd1');
      - a minibatch-discrimination layer, or, when minibatch_layer is False,
        a third ReLU layer (scope 'd2');
      - a final linear unit (scope 'd3') followed by a sigmoid.
    '''
    width = h_dim * 2
    hidden = tf.nn.relu(linear(input, width, 'd0'))
    hidden = tf.nn.relu(linear(hidden, width, 'd1'))
    # Without the minibatch layer the discriminator needs an additional layer
    # to have enough capacity to separate the two distributions correctly.
    hidden = (minibatch(hidden) if minibatch_layer
              else tf.nn.relu(linear(hidden, width, scope='d2')))
    return tf.sigmoid(linear(hidden, 1, scope='d3'))
def minibatch(input, num_kernels=5, kernel_dim=3):
    '''
    Minibatch discrimination layer (Salimans et al., 2016).

    Steps:
      1. Project each row of `input` to num_kernels vectors of length
         kernel_dim.
      2. For each kernel, take the L1 distance between every pair of rows in
         the batch, including a row with itself (distance 0, contributing 1).
      3. For row i, sum exp(-distance) over all rows j.

    The resulting num_kernels values per row are appended to `input` along
    axis 1.
    '''
    projected = linear(input, num_kernels * kernel_dim, scope='minibatch',
                       stddev=0.02)
    # (batch, kernels, dim)
    m = tf.reshape(projected, (-1, num_kernels, kernel_dim))
    # (batch, kernels, dim, 1) broadcast against (1, kernels, dim, batch)
    # -> (batch, kernels, dim, batch): every pairwise row difference.
    left = tf.expand_dims(m, 3)
    right = tf.expand_dims(tf.transpose(m, [1, 2, 0]), 0)
    l1 = tf.reduce_sum(tf.abs(left - right), 2)   # (batch, kernels, batch)
    features = tf.reduce_sum(tf.exp(-l1), 2)      # (batch, kernels)
    return tf.concat([input, features], 1)
def optimizer(loss, var_list, learning_rate=0.001):
    '''
    Build an Adam training op that minimises `loss` with respect to
    `var_list` only.

    Args:
        loss: scalar loss tensor.
        var_list: the variables the op may update; no other variables are
            touched.
        learning_rate: Adam step size (default 0.001, the previously
            hard-coded value).

    Returns:
        The op returned by AdamOptimizer.minimize. Running it performs one
        update and increments a non-trainable step counter created fresh for
        this optimizer.
    '''
    step = tf.Variable(0, trainable=False)
    # Named train_op so the local does not shadow this function's own name.
    train_op = tf.train.AdamOptimizer(learning_rate).minimize(
        loss,
        global_step=step,
        var_list=var_list
    )
    return train_op
def log(x, eps=1e-5):
    '''
    Numerically safe natural log: tf.log(max(x, eps)).

    Discriminator outputs (and 1 - D) can get arbitrarily close to zero, or
    even slightly below it because of floating-point rounding. Clamping at
    `eps` keeps such values out of tf.log, so optimisation does not produce
    infinities or NaNs.

    Args:
        x: input tensor.
        eps: lower bound applied before taking the log (default 1e-5, the
            previously hard-coded value).
    '''
    return tf.log(tf.maximum(x, eps))
class GAN(object):
    '''
    Builds the TensorFlow graph for a 1-D GAN.

    `params` must provide batch_size, hidden_size and minibatch (for example
    the namespace returned by parse_args).

    Attributes set on the instance:
        z, x: float32 placeholders of shape (batch_size, 1) for noise and
            real data.
        G: generator output for z.
        D1, D2: discriminator outputs for x and for G; both use the same
            weights.
        loss_d, loss_g: discriminator and generator losses.
        d_params, g_params: trainable variables whose names start with 'D/'
            and 'G/' respectively.
        opt_d, opt_g: Adam training ops that update only d_params and only
            g_params.
    '''
    def __init__(self, params):
        # Generator: takes samples from a noise distribution as input and
        # passes them through an MLP.
        with tf.variable_scope('G'):
            self.z = tf.placeholder(tf.float32, shape=(params.batch_size, 1))
            self.G = generator(self.z, params.hidden_size)

        # The discriminator tries to tell samples from the true data
        # distribution (self.x) apart from generated samples (self.G).
        #
        # It is built twice, once per input. The second copy is created with
        # reuse=True, so both copies share the same variables.
        self.x = tf.placeholder(tf.float32, shape=(params.batch_size, 1))
        with tf.variable_scope('D'):
            self.D1 = discriminator(
                self.x,
                params.hidden_size,
                params.minibatch
            )
        with tf.variable_scope('D', reuse=True):
            self.D2 = discriminator(
                self.G,
                params.hidden_size,
                params.minibatch
            )

        # Losses for the discriminator and generator networks (see the
        # original GAN paper for details). The generator uses the
        # non-saturating form -log(D(G(z))).
        self.loss_d = tf.reduce_mean(-log(self.D1) - log(1 - self.D2))
        self.loss_g = tf.reduce_mean(-log(self.D2))

        # Split the trainable variables by scope, so each optimizer updates
        # only its own network. The name `trainable` avoids shadowing the
        # builtin `vars`.
        trainable = tf.trainable_variables()
        self.d_params = [v for v in trainable if v.name.startswith('D/')]
        self.g_params = [v for v in trainable if v.name.startswith('G/')]

        self.opt_d = optimizer(self.loss_d, self.d_params)
        self.opt_g = optimizer(self.loss_g, self.g_params)
def train(model, data, gen, params):
    '''
    Alternately train the discriminator and the generator of `model`.

    Each of the params.num_steps + 1 steps performs:
      1. one discriminator update on a batch of real samples (from `data`)
         and a batch of noise (from `gen`);
      2. one generator update on a fresh noise batch.

    Losses are printed every params.log_every steps.

    If params.anim_path is set, a samples() snapshot is recorded every
    params.anim_every steps, and save_animation() writes them out at the
    end. Otherwise the final distributions are shown with
    plot_distributions().
    '''
    batch_shape = (params.batch_size, 1)
    anim_frames = []

    with tf.Session() as session:
        tf.local_variables_initializer().run()
        tf.global_variables_initializer().run()

        for step in range(params.num_steps + 1):
            # Discriminator update.
            real = np.reshape(data.sample(params.batch_size), batch_shape)
            noise = np.reshape(gen.sample(params.batch_size), batch_shape)
            loss_d, _ = session.run(
                [model.loss_d, model.opt_d],
                {model.x: real, model.z: noise}
            )

            # Generator update on a fresh noise batch.
            noise = np.reshape(gen.sample(params.batch_size), batch_shape)
            loss_g, _ = session.run(
                [model.loss_g, model.opt_g],
                {model.z: noise}
            )

            if step % params.log_every == 0:
                print('{}: {:.4f}\t{:.4f}'.format(step, loss_d, loss_g))

            if params.anim_path and step % params.anim_every == 0:
                anim_frames.append(
                    samples(model, session, data, gen.range, params.batch_size)
                )

        # Still inside the session: samples() needs it to evaluate the graph.
        if params.anim_path:
            save_animation(anim_frames, params.anim_path, gen.range)
        else:
            final = samples(model, session, data, gen.range, params.batch_size)
            plot_distributions(final, gen.range)
def samples(
    model,
    session,
    data,
    sample_range,
    batch_size,
    num_points=10000,
    num_bins=100
):
    '''
    Return a tuple (db, pd, pg) describing the model's current state.

    Returns:
        db: the "decision boundary", i.e. model.D1 evaluated at num_points
            evenly spaced x values over [-sample_range, sample_range].
            Shape (num_points, 1).
        pd: density histogram of num_points samples drawn from `data`.
        pg: density histogram of model.G evaluated at num_points evenly
            spaced noise values over [-sample_range, sample_range].

    Both histograms use num_bins evenly spaced edges over
    [-sample_range, sample_range], giving num_bins - 1 bins.

    Every point is evaluated, even when num_points is not a multiple of
    batch_size. Previously the trailing num_points % batch_size outputs
    were silently left at 0.
    '''
    xs = np.linspace(-sample_range, sample_range, num_points)
    bins = np.linspace(-sample_range, sample_range, num_bins)

    # decision boundary
    db = _run_batched(session, model.D1, model.x, xs, batch_size)

    # data distribution
    d = data.sample(num_points)
    pd, _ = np.histogram(d, bins=bins, density=True)

    # generated samples
    zs = np.linspace(-sample_range, sample_range, num_points)
    g = _run_batched(session, model.G, model.z, zs, batch_size)
    pg, _ = np.histogram(g, bins=bins, density=True)

    return db, pd, pg


def _run_batched(session, tensor, placeholder, inputs, batch_size):
    '''
    Evaluate `tensor` for every element of 1-D `inputs`.

    `placeholder` is fed in batches of exactly batch_size rows, because the
    graph's placeholders have a fixed (batch_size, 1) shape. The final
    partial batch is padded by repeating its last value, and the outputs for
    the padding rows are discarded.

    If the fed network mixes information across the batch (e.g. the
    minibatch layer), the padding rows can influence the last batch's
    outputs.

    Returns:
        Array of shape (len(inputs), 1).
    '''
    num_points = len(inputs)
    out = np.zeros((num_points, 1))
    for start in range(0, num_points, batch_size):
        chunk = inputs[start:start + batch_size]
        n = len(chunk)
        if n < batch_size:
            chunk = np.concatenate(
                [chunk, np.repeat(chunk[-1], batch_size - n)]
            )
        result = session.run(
            tensor,
            {placeholder: np.reshape(chunk, (batch_size, 1))}
        )
        out[start:start + n] = np.reshape(result, (batch_size, 1))[:n]
    return out
def plot_distributions(samps, sample_range):
    '''
    Show one matplotlib figure with three series: the decision boundary,
    the real-data histogram and the generated-data histogram.

    `samps` is the (db, pd, pg) tuple produced by samples(). Each series is
    plotted against evenly spaced x values over
    [-sample_range, sample_range], and the y axis is fixed to [0, 1].
    '''
    db, pd, pg = samps
    _, ax = plt.subplots(1)
    ax.plot(np.linspace(-sample_range, sample_range, len(db)), db,
            label='decision boundary')
    ax.set_ylim(0, 1)
    hist_x = np.linspace(-sample_range, sample_range, len(pd))
    ax.plot(hist_x, pd, label='real data')
    ax.plot(hist_x, pg, label='generated data')
    ax.set_title('1D Generative Adversarial Network')
    ax.set_xlabel('Data values')
    ax.set_ylabel('Probability density')
    ax.legend()
    plt.show()
def save_animation(anim_frames, anim_path, sample_range):
    '''
    Render the recorded training snapshots as an animation saved to
    `anim_path`.

    `anim_frames` is a list of (db, pd, pg) tuples as produced by samples().
    Each tuple becomes one frame showing three series: the decision
    boundary, the real-data histogram and the generated-data histogram.

    The axes are fixed to x in [-6, 6] and y in [0, 1.4]. The animation is
    saved at 30 fps with extra writer args '-vcodec libx264'.
    '''
    fig, ax = plt.subplots(figsize=(6, 4))
    fig.suptitle('1D Generative Adversarial Network', fontsize=15)
    plt.xlabel('Data values')
    plt.ylabel('Probability density')
    ax.set_xlim(-6, 6)
    ax.set_ylim(0, 1.4)

    line_db, line_pd, line_pg = [
        ax.plot([], [], label=label)[0]
        for label in ('decision boundary', 'real data', 'generated data')
    ]
    frame_label = ax.text(
        0.02, 0.95, '',
        horizontalalignment='left',
        verticalalignment='top',
        transform=ax.transAxes
    )
    ax.legend()

    # The x grids are sized from the first frame only; this assumes every
    # frame has the same series lengths as the first.
    first_db, first_pd, _ = anim_frames[0]
    db_x = np.linspace(-sample_range, sample_range, len(first_db))
    p_x = np.linspace(-sample_range, sample_range, len(first_pd))
    artists = (line_db, line_pd, line_pg, frame_label)

    def init():
        for line in (line_db, line_pd, line_pg):
            line.set_data([], [])
        frame_label.set_text('')
        return artists

    def animate(i):
        frame_label.set_text('Frame: {}/{}'.format(i, len(anim_frames)))
        db, pd, pg = anim_frames[i]
        line_db.set_data(db_x, db)
        line_pd.set_data(p_x, pd)
        line_pg.set_data(p_x, pg)
        return artists

    anim = animation.FuncAnimation(
        fig, animate, init_func=init, frames=len(anim_frames), blit=True
    )
    anim.save(anim_path, fps=30, extra_args=['-vcodec', 'libx264'])
def main(args):
    '''
    Build the GAN graph from `args` and train it.

    Training uses the default DataDistribution as the target and a
    GeneratorDistribution spanning [-8, 8] as the noise source.
    '''
    model = GAN(args)
    data = DataDistribution()
    gen = GeneratorDistribution(range=8)
    train(model, data, gen, args)
def parse_args(argv=None):
    '''
    Parse the command-line options for the GAN demo.

    Args:
        argv: list of argument strings to parse. None (the default) means
            sys.argv[1:], which matches the previous behaviour.

    Returns:
        argparse.Namespace with num_steps, hidden_size, batch_size,
        minibatch, log_every, anim_path and anim_every.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--num-steps', type=int, default=5000,
                        help='the number of training steps to take')
    parser.add_argument('--hidden-size', type=int, default=4,
                        help='MLP hidden size')
    parser.add_argument('--batch-size', type=int, default=8,
                        help='the batch size')
    parser.add_argument('--minibatch', action='store_true',
                        help='use minibatch discrimination')
    parser.add_argument('--log-every', type=int, default=10,
                        help='print loss after this many steps')
    parser.add_argument('--anim-path', type=str, default=None,
                        help='path to the output animation file')
    parser.add_argument('--anim-every', type=int, default=1,
                        help='save every Nth frame for animation')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: parse the CLI options, then build and train.
    cli_args = parse_args()
    main(cli_args)