Skip to content
This repository has been archived by the owner on Apr 3, 2022. It is now read-only.

updates to be compatible with tensorflow 1.1 #34

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion srez_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def setup_tensorflow():
random.seed(FLAGS.random_seed)
np.random.seed(FLAGS.random_seed)

summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

return sess, summary_writer

Expand Down
14 changes: 7 additions & 7 deletions srez_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def _discriminator_model(sess, features, disc_input):
mapsize = 3
layers = [64, 128, 256, 512]

old_vars = tf.all_variables()
old_vars = tf.global_variables()

model = Model('DIS', 2*disc_input - 1)

Expand All @@ -352,7 +352,7 @@ def _discriminator_model(sess, features, disc_input):
model.add_conv2d(1, mapsize=1, stride=1, stddev_factor=stddev_factor)
model.add_mean()

new_vars = tf.all_variables()
new_vars = tf.global_variables()
disc_vars = list(set(new_vars) - set(old_vars))

return model.get_output(), disc_vars
Expand All @@ -363,7 +363,7 @@ def _generator_model(sess, features, labels, channels):
mapsize = 3
res_units = [256, 128, 96]

old_vars = tf.all_variables()
old_vars = tf.global_variables()

# See Arxiv 1603.05027
model = Model('GEN', features)
Expand Down Expand Up @@ -396,7 +396,7 @@ def _generator_model(sess, features, labels, channels):
model.add_conv2d(channels, mapsize=1, stride=1, stddev_factor=1.)
model.add_sigmoid()

new_vars = tf.all_variables()
new_vars = tf.global_variables()
gene_vars = list(set(new_vars) - set(old_vars))

return model.get_output(), gene_vars
Expand Down Expand Up @@ -449,7 +449,7 @@ def _downscale(images, K):

def create_generator_loss(disc_output, gene_output, features):
# I.e. did we fool the discriminator?
cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(disc_output, tf.ones_like(disc_output))
cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_output, labels=tf.ones_like(disc_output))
gene_ce_loss = tf.reduce_mean(cross_entropy, name='gene_ce_loss')

# I.e. does the result look like the feature?
Expand All @@ -466,10 +466,10 @@ def create_generator_loss(disc_output, gene_output, features):

def create_discriminator_loss(disc_real_output, disc_fake_output):
# I.e. did we correctly identify the input as real or not?
cross_entropy_real = tf.nn.sigmoid_cross_entropy_with_logits(disc_real_output, tf.ones_like(disc_real_output))
cross_entropy_real = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real_output, labels=tf.ones_like(disc_real_output))
disc_real_loss = tf.reduce_mean(cross_entropy_real, name='disc_real_loss')

cross_entropy_fake = tf.nn.sigmoid_cross_entropy_with_logits(disc_fake_output, tf.zeros_like(disc_fake_output))
cross_entropy_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake_output, labels=tf.zeros_like(disc_fake_output))
disc_fake_loss = tf.reduce_mean(cross_entropy_fake, name='disc_fake_loss')

return disc_real_loss, disc_fake_loss
Expand Down
8 changes: 4 additions & 4 deletions srez_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def _summarize_progress(train_data, feature, label, gene_output, batch, suffix,

clipped = tf.maximum(tf.minimum(gene_output, 1.0), 0.0)

image = tf.concat(2, [nearest, bicubic, clipped, label])
image = tf.concat(axis=2, values=[nearest, bicubic, clipped, label])

image = image[0:max_samples,:,:,:]
image = tf.concat(0, [image[i,:,:,:] for i in range(max_samples)])
image = tf.concat(axis=0, values=[image[i,:,:,:] for i in range(max_samples)])
image = td.sess.run(image)

filename = 'batch%06d_%s.png' % (batch, suffix)
Expand Down Expand Up @@ -62,8 +62,8 @@ def _save_checkpoint(train_data, batch):
def train_model(train_data):
td = train_data

summaries = tf.merge_all_summaries()
td.sess.run(tf.initialize_all_variables())
summaries = tf.summary.merge_all()
td.sess.run(tf.global_variables_initializer())

lrval = FLAGS.learning_rate_start
start_time = time.time()
Expand Down