forked from david-gpu/srez
-
Notifications
You must be signed in to change notification settings - Fork 1
/
srez_train.py
120 lines (91 loc) · 3.94 KB
/
srez_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import numpy as np
import os.path
import scipy.misc
import tensorflow as tf
import time
# Shared command-line flag container. The functions below read train_dir,
# checkpoint_dir, learning_rate_start, learning_rate_half_life, loss,
# train_time, summary_period and checkpoint_period from it; the flag
# definitions themselves are not in this file.
FLAGS = tf.app.flags.FLAGS
def _summarize_progress(train_data, feature, label, gene_output, batch, suffix, max_samples=8):
    """Write an image summary and a PNG comparing upscaling results.

    Each row of the saved picture shows, left to right: the input upscaled
    with nearest-neighbour, the input upscaled with bicubic interpolation,
    the generator output, and the ground-truth label.  All but the label are
    clipped to [0, 1] before being joined.

    Args:
        train_data: object exposing `sess` (used to run ops) and
            `summary_writer` (receives the image summary).
        feature: batch of input images, resized to the label's size.
            Assumed to be 4-D [batch, height, width, channels], as the
            tf.image resize ops require -- TODO confirm against caller.
        label: batch of target images; axes 1 and 2 give the resize size.
        gene_output: batch of generator outputs matching `label` in shape.
        batch: training step, used as the summary step and in the file name.
        suffix: text appended to the PNG file name.
        max_samples: upper bound on the number of rows in the summary and
            in the saved PNG.
    """
    td = train_data

    size = [label.shape[1], label.shape[2]]

    nearest = tf.image.resize_nearest_neighbor(feature, size)
    nearest = tf.maximum(tf.minimum(nearest, 1.0), 0.0)

    bicubic = tf.image.resize_bicubic(feature, size)
    bicubic = tf.maximum(tf.minimum(bicubic, 1.0), 0.0)

    clipped = tf.maximum(tf.minimum(gene_output, 1.0), 0.0)

    # Side-by-side comparison: join the four variants along axis 2.
    image = tf.concat([nearest, bicubic, clipped, label], 2)
    image_op = tf.summary.image('generator output', image, max_samples)

    # Fetch the summary and the raw pixels in a single run so the resize
    # ops are evaluated once and no extra slicing ops are added to the
    # graph on every call.
    image_summary, image_value = td.sess.run([image_op, image])
    td.summary_writer.add_summary(image_summary, batch)

    # Stack the samples vertically.  Clamp to the real batch size so a
    # batch smaller than max_samples does not fail on an out-of-range index.
    num_samples = min(max_samples, image_value.shape[0])
    montage = np.concatenate(list(image_value[:num_samples]), axis=0)

    filename = 'batch%06d_%s.png' % (batch, suffix)
    filename = os.path.join(FLAGS.train_dir, filename)
    # NOTE: scipy.misc.toimage only exists in older SciPy releases (it was
    # removed in SciPy 1.2) and needs PIL/Pillow installed.
    scipy.misc.toimage(montage, cmin=0., cmax=1.).save(filename)
    print(" Saved %s" % (filename,))
def _save_checkpoint(train_data, batch):
    """Save a checkpoint, keeping the previous one as a backup.

    The files of the current 'checkpoint_new.txt' checkpoint are renamed to
    'checkpoint_old.txt' (replacing any older backup) and a fresh
    'checkpoint_new.txt' is written from `train_data.sess`.

    A checkpoint may consist of several files that share the save path as a
    prefix (for example '.meta', '.index' and '.data-*' companions), so
    every file starting with the checkpoint name is rotated, each one
    independently of the others.

    Args:
        train_data: object exposing `sess`, the session whose variables
            are saved.
        batch: current training step (unused; kept for the existing
            call signature).
    """
    td = train_data

    old_base = 'checkpoint_old.txt'
    new_base = 'checkpoint_new.txt'
    oldname = os.path.join(FLAGS.checkpoint_dir, old_base)
    newname = os.path.join(FLAGS.checkpoint_dir, new_base)

    # Delete oldest checkpoint.  Rotation is best-effort: a failure is
    # reported but must not prevent the new checkpoint from being written.
    try:
        stale_files = tf.gfile.Glob(oldname + '*')
    except tf.errors.OpError as err:
        print(" Could not list %s*: %s" % (oldname, err))
        stale_files = []
    for path in stale_files:
        try:
            tf.gfile.Remove(path)
        except tf.errors.OpError as err:
            print(" Could not remove %s: %s" % (path, err))

    # Rename old checkpoint, preserving each file's suffix.
    try:
        current_files = tf.gfile.Glob(newname + '*')
    except tf.errors.OpError as err:
        print(" Could not list %s*: %s" % (newname, err))
        current_files = []
    for path in current_files:
        # Derive the suffix from the base name so it does not matter how
        # the glob normalises the directory part of the path.
        file_suffix = os.path.basename(path)[len(new_base):]
        try:
            tf.gfile.Rename(path, oldname + file_suffix)
        except tf.errors.OpError as err:
            print(" Could not rename %s: %s" % (path, err))

    # Generate new checkpoint
    saver = tf.train.Saver()
    saver.save(td.sess, newname)

    print(" Checkpoint saved")
def train_model(train_data):
    """Run the training loop until the wall-clock budget is exhausted.

    Every batch runs the generator and discriminator minimisation ops
    (preceded by the discriminator clipping op when FLAGS.loss is 'wgan').
    Every 10 batches the function prints progress, writes the merged
    summaries, checks whether FLAGS.train_time minutes have elapsed, and
    halves the learning rate on multiples of
    FLAGS.learning_rate_half_life.  Test-set previews and checkpoints are
    produced every FLAGS.summary_period and FLAGS.checkpoint_period
    batches, and a final checkpoint is written when training stops.

    Args:
        train_data: object exposing the session (`sess`), the summary
            writer, the `learning_rate` placeholder, the training and loss
            ops, the test tensors and the generator test input/output.

    Raises:
        ValueError: if FLAGS.learning_rate_half_life is not a multiple
            of 10.
    """
    td = train_data

    # The learning-rate update is only evaluated on every 10th batch, so
    # the half life must be a multiple of 10 to fire on schedule.  Checked
    # explicitly because an `assert` is stripped when Python runs with -O.
    if FLAGS.learning_rate_half_life % 10 != 0:
        raise ValueError(
            'learning_rate_half_life must be a multiple of 10, got %r'
            % (FLAGS.learning_rate_half_life,))

    # merge_all() returns None when no summaries have been registered.
    summaries = tf.summary.merge_all()
    td.sess.run(tf.global_variables_initializer())

    lrval = FLAGS.learning_rate_start
    start_time = time.time()
    done = False
    batch = 0

    # Cache test features and labels (they are small)
    test_feature, test_label = td.sess.run([td.test_features, td.test_labels])

    while not done:
        batch += 1

        feed_dict = {td.learning_rate: lrval}

        if FLAGS.loss == 'wgan':
            td.sess.run(td.d_clip, feed_dict=feed_dict)

        ops = [td.gene_minimize, td.disc_minimize,
               td.gene_loss, td.disc_real_loss, td.disc_fake_loss]
        _, _, gene_loss, disc_real_loss, disc_fake_loss = td.sess.run(ops, feed_dict=feed_dict)

        if batch % 10 == 0:
            # Show we are alive.  Elapsed time is in minutes, the same
            # unit as FLAGS.train_time.
            elapsed = int(time.time() - start_time)/60
            print('Progress[%3d%%], ETA[%4dm], Batch [%4d], G_Loss[%3.3f], D_Real_Loss[%3.3f], D_Fake_Loss[%3.3f]' %
                  (int(100*elapsed/FLAGS.train_time), FLAGS.train_time - elapsed,
                   batch, gene_loss, disc_real_loss, disc_fake_loss))

            if summaries is not None:
                batch_summaries = td.sess.run(summaries)
                td.summary_writer.add_summary(batch_summaries, batch)

            # Finished?
            current_progress = elapsed / FLAGS.train_time
            if current_progress >= 1.0:
                done = True

            # Update learning rate
            if batch % FLAGS.learning_rate_half_life == 0:
                lrval *= .5

        if batch % FLAGS.summary_period == 0:
            # Show progress with test features
            feed_dict = {td.gene_minput: test_feature}
            gene_output = td.sess.run(td.gene_moutput, feed_dict=feed_dict)
            _summarize_progress(td, test_feature, test_label, gene_output, batch, 'out')

        if batch % FLAGS.checkpoint_period == 0:
            # Save checkpoint
            _save_checkpoint(td, batch)

    # Always leave a checkpoint of the final state.
    _save_checkpoint(td, batch)
    print('Finished training!')