From 30cbea78c18514d8965708d79ad08c56f441c5a8 Mon Sep 17 00:00:00 2001 From: Matthias Planitzer Date: Wed, 16 Aug 2017 11:32:05 +0200 Subject: [PATCH 1/4] updates to be compatible with tensorflow 1.2.1 --- srez_train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/srez_train.py b/srez_train.py index 8f90343..6a3eb7b 100644 --- a/srez_train.py +++ b/srez_train.py @@ -19,10 +19,10 @@ def _summarize_progress(train_data, feature, label, gene_output, batch, suffix, clipped = tf.maximum(tf.minimum(gene_output, 1.0), 0.0) - image = tf.concat(2, [nearest, bicubic, clipped, label]) + image = tf.concat([nearest, bicubic, clipped, label], 2) image = image[0:max_samples,:,:,:] - image = tf.concat(0, [image[i,:,:,:] for i in range(max_samples)]) + image = tf.concat([image[i,:,:,:] for i in range(max_samples)], 0) image = td.sess.run(image) filename = 'batch%06d_%s.png' % (batch, suffix) @@ -62,8 +62,8 @@ def _save_checkpoint(train_data, batch): def train_model(train_data): td = train_data - summaries = tf.merge_all_summaries() - td.sess.run(tf.initialize_all_variables()) + summaries = tf.summary.merge_all() + td.sess.run(tf.global_variables_initializer()) lrval = FLAGS.learning_rate_start start_time = time.time() From 6fd83f9328c5b85b1603a1878b857b63d47cc8a0 Mon Sep 17 00:00:00 2001 From: Matthias Planitzer Date: Wed, 16 Aug 2017 11:32:50 +0200 Subject: [PATCH 2/4] updates to be compatible with tensorflow 1.2.1 --- srez_model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/srez_model.py b/srez_model.py index 3075ae7..aae4f59 100644 --- a/srez_model.py +++ b/srez_model.py @@ -326,7 +326,7 @@ def _discriminator_model(sess, features, disc_input): mapsize = 3 layers = [64, 128, 256, 512] - old_vars = tf.all_variables() + old_vars = tf.global_variables() model = Model('DIS', 2*disc_input - 1) @@ -352,7 +352,7 @@ def _discriminator_model(sess, features, disc_input): model.add_conv2d(1, mapsize=1, stride=1, stddev_factor=stddev_factor) model.add_mean() - new_vars = tf.all_variables() + new_vars = tf.global_variables() disc_vars = list(set(new_vars) - set(old_vars)) return model.get_output(), disc_vars @@ -363,7 +363,7 @@ def _generator_model(sess, features, labels, channels): mapsize = 3 res_units = [256, 128, 96] - old_vars = tf.all_variables() + old_vars = tf.global_variables() # See Arxiv 1603.05027 model = Model('GEN', features) @@ -396,7 +396,7 @@ def _generator_model(sess, features, labels, channels): model.add_conv2d(channels, mapsize=1, stride=1, stddev_factor=1.) model.add_sigmoid() - new_vars = tf.all_variables() + new_vars = tf.global_variables() gene_vars = list(set(new_vars) - set(old_vars)) return model.get_output(), gene_vars @@ -449,7 +449,7 @@ def _downscale(images, K): def create_generator_loss(disc_output, gene_output, features): # I.e. did we fool the discriminator? - cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(disc_output, tf.ones_like(disc_output)) + cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=disc_output, logits=tf.ones_like(disc_output)) gene_ce_loss = tf.reduce_mean(cross_entropy, name='gene_ce_loss') # I.e. does the result look like the feature? @@ -466,10 +466,10 @@ def create_generator_loss(disc_output, gene_output, features): def create_discriminator_loss(disc_real_output, disc_fake_output): # I.e. did we correctly identify the input as real or not? - cross_entropy_real = tf.nn.sigmoid_cross_entropy_with_logits(disc_real_output, tf.ones_like(disc_real_output)) + cross_entropy_real = tf.nn.sigmoid_cross_entropy_with_logits(labels=disc_real_output, logits=tf.ones_like(disc_real_output)) disc_real_loss = tf.reduce_mean(cross_entropy_real, name='disc_real_loss') - cross_entropy_fake = tf.nn.sigmoid_cross_entropy_with_logits(disc_fake_output, tf.zeros_like(disc_fake_output)) + cross_entropy_fake = tf.nn.sigmoid_cross_entropy_with_logits(labels=disc_fake_output, logits=tf.zeros_like(disc_fake_output)) disc_fake_loss = tf.reduce_mean(cross_entropy_fake, name='disc_fake_loss') return disc_real_loss, disc_fake_loss From d38bbc728d4a5edda122e0760632d8a02a03f278 Mon Sep 17 00:00:00 2001 From: Matthias Planitzer Date: Wed, 16 Aug 2017 11:33:20 +0200 Subject: [PATCH 3/4] updates to be compatible with tensorflow 1.2.1 --- srez_main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/srez_main.py b/srez_main.py index beaea9e..b73bb26 100644 --- a/srez_main.py +++ b/srez_main.py @@ -100,7 +100,7 @@ def setup_tensorflow(): random.seed(FLAGS.random_seed) np.random.seed(FLAGS.random_seed) - summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph) + summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph) return sess, summary_writer From 09b0033f1f7852738cc18e59aa2fef58d0b69c65 Mon Sep 17 00:00:00 2001 From: Matthias Planitzer Date: Wed, 16 Aug 2017 12:46:27 +0200 Subject: [PATCH 4/4] require tensorflow 1.2.1 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index d769b8c..3454c1a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,4 @@ moviepy==0.2.2.11 numpy==1.11.1 scipy==0.18.0 six==1.10.0 -tensorflow==0.10.0rc0 +tensorflow==1.2.1