-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGAN_example_2.py
143 lines (110 loc) · 5.59 KB
/
GAN_example_2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./mnist/data/", one_hot=True)
total_epoch = 200
batch_size = 100
n_hidden = 256
n_input = 28 * 28
n_noise = 128
n_class = 10
# 신경망 모델 구성
# X : 실제 데이터들, Y : 최종 출력되는 데이터(0 ~ 9), Z : 가짜 데이터들
X = tf.placeholder(tf.float32, [None, n_input])
# 노이즈와 실제 이미지에, 그에 해당하는 숫자에 대한 정보를 넣어주기 위해 사용합니다.
Y = tf.placeholder(tf.float32, [None, n_class])
Z = tf.placeholder(tf.float32, [None, n_noise])
def generator(noise, labels):
with tf.variable_scope('generator'):
# noise 값에 labels 정보를 추가합니다.
# 학습을 하면서 noise는 해당 label의 실제 데이터에 근접해 간다
inputs = tf.concat([noise, labels], 1)
# TensorFlow 에서 제공하는 유틸리티 함수를 이용해 신경망을 매우 간단하게 구성할 수 있습니다.
hidden = tf.layers.dense(inputs, n_hidden,
activation=tf.nn.relu)
output = tf.layers.dense(hidden, n_input,
activation=tf.nn.sigmoid)
return output
def discriminator(inputs, labels, reuse=None):
with tf.variable_scope('discriminator') as scope:
# 노이즈에서 생성한 이미지와 실제 이미지를 판별하는 모델의 변수를 동일하게 하기 위해,
# 이전에 사용되었던 변수를 재사용하도록 합니다.
if reuse:
scope.reuse_variables()
inputs = tf.concat([inputs, labels], 1)
hidden = tf.layers.dense(inputs, n_hidden,
activation=tf.nn.relu)
output = tf.layers.dense(hidden, 1,
activation=None)
return output
def get_noise(batch_size, n_noise):
return np.random.uniform(-1., 1., size=[batch_size, n_noise])
# 생성 모델과 판별 모델에 Y 즉, labels 정보를 추가하여
# labels 정보에 해당하는 이미지를 생성할 수 있도록 유도합니다.
G = generator(Z, Y)
D_real = discriminator(X, Y)
D_gene = discriminator(G, Y, True)
# 손실함수는 다음을 참고하여 GAN 논문에 나온 방식과는 약간 다르게 작성하였습니다.
# http://bamos.github.io/2016/08/09/deep-completion/
# 진짜 이미지를 판별하는 D_real 값은 1에 가깝도록,
# 가짜 이미지를 판별하는 D_gene 값은 0에 가깝도록 하는 손실 함수입니다.
loss_D_real = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=D_real, labels=tf.ones_like(D_real)))
loss_D_gene = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=D_gene, labels=tf.zeros_like(D_gene)))
# loss_D_real 과 loss_D_gene 을 더한 뒤 이 값을 최소화 하도록 최적화합니다.
loss_D = loss_D_real + loss_D_gene
# 가짜 이미지를 진짜에 가깝게 만들도록 생성망을 학습시키기 위해, D_gene 을 최대한 1에 가깝도록 만드는 손실함수입니다.
loss_G = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=D_gene, labels=tf.ones_like(D_gene)))
# TensorFlow 에서 제공하는 유틸리티 함수를 이용해
# discriminator 와 generator scope 에서 사용된 변수들을 쉽게 가져올 수 있습니다.
vars_D = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope='discriminator')
vars_G = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope='generator')
train_D = tf.train.AdamOptimizer().minimize(loss_D,
var_list=vars_D)
train_G = tf.train.AdamOptimizer().minimize(loss_G,
var_list=vars_G)
#########
# 신경망 모델 학습
######
sess = tf.Session()
sess.run(tf.global_variables_initializer())
total_batch = int(mnist.train.num_examples/batch_size)
loss_val_D, loss_val_G = 0, 0
for epoch in range(total_epoch):
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
noise = get_noise(batch_size, n_noise)
_, loss_val_D = sess.run([train_D, loss_D],
feed_dict={X: batch_xs, Y: batch_ys, Z: noise})
_, loss_val_G = sess.run([train_G, loss_G],
feed_dict={Y: batch_ys, Z: noise})
print('Epoch:', '%04d' % epoch,
'D loss: {:.4}'.format(loss_val_D),
'G loss: {:.4}'.format(loss_val_G))
#########
# 학습이 되어가는 모습을 보기 위해 주기적으로 레이블에 따른 이미지를 생성하여 저장
######
if epoch == 0 or (epoch + 1) % 3 == 0:
sample_size = 10
noise = get_noise(sample_size, n_noise)
samples = sess.run(G,
feed_dict={Y: mnist.test.labels[:sample_size],
Z: noise})
fig, ax = plt.subplots(2, sample_size, figsize=(sample_size, 2))
for i in range(sample_size):
ax[0][i].set_axis_off()
ax[1][i].set_axis_off()
# 실제 데이터와 가짜 데이터를 imshow
ax[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
ax[1][i].imshow(np.reshape(samples[i], (28, 28)))
plt.savefig('041217/samples3/{}.png'.format(str(epoch).zfill(3)), bbox_inches='tight')
plt.close(fig)
print('최적화 완료!')