
Upgrade from r11 to r12 produces "Variables not defined" when using any optimizer but GradientDescentOptimizer #6220

Closed
germanRos-TRI opened this issue Dec 9, 2016 · 21 comments

@germanRos-TRI

germanRos-TRI commented Dec 9, 2016

After a recent upgrade to the latest version of TensorFlow on GitHub, several things stopped working. I found out that all the optimizers, such as Adam or Adagrad, now produce an error related to variable scope that I have not managed to solve yet. However, GradientDescentOptimizer works fine.

It may be related to issue #5652.

The error looks like this:

File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variable_scope.py", line 651, in _get_single_variable
    "VarScope?" % name)
ValueError: Variable filter/Adadelta/ does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?

It works fine with tensorflow r11

Operating System: Ubuntu 16 and Ubuntu 14
Installed version of CUDA and cuDNN: CUDA 8.0, cuDNN 5.1
cuda.txt
The commit hash 6dc8dea
Build time: Wed Nov 2 17:54:14 2016 (1478109254)
Build timestamp: 1478109254
Build timestamp as int: 1478109254

Find below a minimal version that causes the error:

import tensorflow as tf
import pdb

def main():

    ## !!! change this to test the different behaviors !!!
    #optimizer = tf.train.GradientDescentOptimizer(1e-3)                 # This one is working
    optimizer = tf.train.AdamOptimizer(1e-3, beta1=0.9, beta2=0.999999) # This one is not working
    #optimizer = tf.train.AdagradOptimizer(1e-3)                         # This one is not working
    #optimizer = tf.train.AdadeltaOptimizer(1e-3)                        # This one is not working
	
    list_grads = []
    for i in xrange(2):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('%d' % i) as scope:
                W = tf.get_variable(name="filter", initializer=tf.random_uniform_initializer(dtype=tf.float32), shape=[5, 1])
                X = tf.get_variable(name="data", initializer=tf.random_uniform_initializer(dtype=tf.float32), shape=[5, 1])
                Y_ = tf.get_variable(name="out", initializer=tf.random_uniform_initializer(dtype=tf.float32), shape=[5, 1])
                Y = W+X
                loss = tf.reduce_mean(Y - Y_)
                grad = optimizer.compute_gradients(loss)
                list_grads.append(grad)

                tf.get_variable_scope().reuse_variables()	
    
    grads = list_grads[0] + list_grads[1]
    #pdb.set_trace()

    op_train = optimizer.apply_gradients(grads)

    init_global = tf.global_variables_initializer()
    init_local =  tf.local_variables_initializer()

    sess = tf.Session()
    sess.run([init_global, init_local])

    _, sol = sess.run([op_train, loss])
    print(str(sol))

if (__name__ == '__main__'):
	main()
@germanRos-TRI
Author

Comment and uncomment the different optimizers to see the behavior. As explained, the only one that works is GradientDescentOptimizer. This behavior does not occur in version r11, and it also does not happen if the averaging is not performed. Any clue?

@sherrym
Contributor

sherrym commented Dec 10, 2016

In particular, this commit causes the problem: 0fc86dd.

@germanRos-TRI
Author

Please, tell me if I can help.

@sherrym sherrym removed their assignment Dec 10, 2016
@sherrym
Contributor

sherrym commented Dec 10, 2016

You can revert the changes to slot_creator.py, or fix the changes and send a PR. Thanks.

@lukaszkaiser
Contributor

Sorry Sherry -- the current behaviour is correct. Your code is leaking reuse; it just wasn't checked before. That could cause all sorts of other trouble, so I think we should correct the leaky-reuse cases rather than revert the slot change. I'll write more on the test cases; closing this.

@germanRos-TRI
Author

Then, just to be clear: how do we get the desired result? Does the reuse need to be done in a different way? This was taken directly from the cifar10 multi-GPU example.

Thanks

@lukaszkaiser
Contributor

To clarify, we just need to put a scope around the model-construction part.

with tf.variable_scope(tf.get_variable_scope()) as scope:
  for i in xrange(2):
    ... code as before until ...reuse_variables() ....

grads = list_grads[0] + list_grads[1]
... rest of code as before ...

Hope that helps!
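For reference, here is the whole reproduction from above with that wrapping scope applied -- a minimal sketch against the graph-mode API of that era (r0.12/1.x), with allow_soft_placement added here so the sketch also runs on a machine without two GPUs:

import tensorflow as tf

def main():
    optimizer = tf.train.AdamOptimizer(1e-3)

    list_grads = []
    # The wrapping scope keeps reuse_variables() contained: it flips reuse
    # only inside this block, not for the graph built afterwards.
    with tf.variable_scope(tf.get_variable_scope()):
        for i in range(2):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%d' % i):
                    W = tf.get_variable("filter", shape=[5, 1],
                                        initializer=tf.random_uniform_initializer())
                    X = tf.get_variable("data", shape=[5, 1],
                                        initializer=tf.random_uniform_initializer())
                    Y_ = tf.get_variable("out", shape=[5, 1],
                                         initializer=tf.random_uniform_initializer())
                    loss = tf.reduce_mean((W + X) - Y_)
                    list_grads.append(optimizer.compute_gradients(loss))
                    tf.get_variable_scope().reuse_variables()

    # Back outside the wrapping scope, reuse is False again, so Adam can
    # create its slot variables here without raising the error.
    grads = list_grads[0] + list_grads[1]
    op_train = optimizer.apply_gradients(grads)

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    _, sol = sess.run([op_train, loss])
    print(sol)

if __name__ == '__main__':
    main()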

@germanRos-TRI
Author

Thanks lukaszkaiser. It works perfectly fine now!

@wookayin
Contributor

wookayin commented Jan 15, 2017

@lukaszkaiser Hello, I found your workaround of putting a variable_scope that wraps the outermost num_gpus loop, but I am still confused about why it eliminates the error.

with tf.variable_scope(tf.get_variable_scope()) as vscope:
  for i in xrange(FLAGS.num_gpus):
    with tf.device('/gpu:%d' % i):
      with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
        loss = tower_loss(scope)
        tf.get_variable_scope().reuse_variables()     # HERE

Is it just because the tf.get_variable_scope() (which is identical to vscope) is explicitly created, rather than the implicit default being used? And how do these two VariableScope objects differ?

What do you mean by "leaky reuse"? Could you please clarify?
/cc @cesc-park

@lukaszkaiser
Contributor

Sure, let me try to clarify.

When you do tf.get_variable_scope().reuse_variables(), you set the current scope to reuse variables. If you then call the optimizer in such a scope, it tries to reuse its slot variables, which it cannot find, so it throws an error. If you put a scope around it, the tf.get_variable_scope().reuse_variables() only affects that scope, so when you exit it you are back in non-reusing mode, which is the one you want.

Hope that helps, let me know if I should clarify more.
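A quick way to see the difference, as a minimal sketch against the TF 1.x graph-mode API:

import tensorflow as tf

# Without a wrapping scope: reuse_variables() flips the root scope to
# reuse mode, and it stays that way for everything built afterwards.
with tf.Graph().as_default():
    tf.get_variable_scope().reuse_variables()
    print(tf.get_variable_scope().reuse)      # True -- "leaked" to the rest of the graph

# With a wrapping scope: the flag is confined to the scope you entered
# and is gone again once the block exits.
with tf.Graph().as_default():
    with tf.variable_scope(tf.get_variable_scope()):
        tf.get_variable_scope().reuse_variables()
        print(tf.get_variable_scope().reuse)  # True inside the block
    print(tf.get_variable_scope().reuse)      # False outside again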

@wookayin
Contributor

wookayin commented Jan 16, 2017

Ah, great. Your explanation is clear and helpful. Thanks!

To sum up, the thing to remember is that the point where the (Adam-like) optimizer acts, i.e. opt.apply_gradients(...) (which is where the error is thrown from), should lie in a scope with reuse=False so that the slot variables can be created properly.
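As a minimal sketch of that rule (assuming TF 1.x graph mode):

import tensorflow as tf

with tf.variable_scope(tf.get_variable_scope()):
    w = tf.get_variable("w", shape=[3], initializer=tf.zeros_initializer())
    loss = tf.reduce_sum(tf.square(w - 1.0))
    tf.get_variable_scope().reuse_variables()   # reuse only inside this block

opt = tf.train.AdamOptimizer(1e-3)
# apply_gradients runs with reuse=False here, so the slot variables get created.
train_op = opt.apply_gradients(opt.compute_gradients(loss))

print(opt.get_slot(w, "m").name)                # Adam's first-moment slot for w
print([v.name for v in tf.global_variables()])  # model variable plus the slots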

wookayin added a commit to wookayin/tensorflow-models that referenced this issue Jan 17, 2017
Without the new variable_scope, creating apply_gradient_op raises
an error that additional moving average or slot variables could not
be created. This is because of the 'leaky reuse' of variable scope,
so we correct the problem by explicitly introducing a new variable scope.

Related issues: tensorflow#901, tensorflow/tensorflow#6220
wookayin added a commit to wookayin/tensorflow-models that referenced this issue Feb 13, 2017
Without the new variable_scope, creating apply_gradient_op raises
an error that additional moving average or slot variables could not
be created. This is because of the 'leaky reuse' of variable scope,
so we correct the problem by explicitly introducing a new variable scope.

Related issues: tensorflow#901, tensorflow/tensorflow#6220
Peratham added a commit to Peratham/models that referenced this issue Mar 1, 2017
(Large batch of tensorflow/models updates; the entry relevant to this issue:)

* Wrap the cifar10 multigpu model construction part with a variable_scope

Without the new variable_scope, creating apply_gradient_op raises
an error that additional moving average or slot variables could not
be created. This is because of the 'leaky reuse' of variable scope,
so we correct the problem by explicitly introducing a new variable scope.

Related issues: tensorflow#901, tensorflow/tensorflow#6220
@yanxp

yanxp commented Mar 14, 2017

@wagonhelm Hello, when using Adam, how did you solve it? Please tell me the details, thank you.

taylorpaul pushed a commit to taylorpaul/cifar10_tf that referenced this issue Mar 25, 2017
Without the new variable_scope, creating apply_gradient_op raises
an error that additional moving average or slot variables could not
be created. This is because of the 'leaky reuse' of variable scope,
so we correct the problem by explicitly introducing a new variable scope.

Related issues: tensorflow/models#901, tensorflow/tensorflow#6220
@zhouyang209117

zhouyang209117 commented Jul 12, 2017

Find below a minimal version that causes the error:

import tensorflow as tf
import numpy as np
class SimpleModel:
    def __init__(self):
        self.loss = self.calc_loss()
        self.train = self.train_model(self.loss)
    def calc_loss(self):
        W = tf.get_variable("w", [1])
        b = tf.Variable(tf.zeros([1]))
        y = W * x_data + b
        return tf.reduce_mean(tf.square(y - y_data))
    def train_model(self, loss):
        return tf.train.AdamOptimizer(0.5).minimize(loss)
        # return tf.train.GradientDescentOptimizer(0.5)
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3
s1 = SimpleModel()
tf.get_variable_scope().reuse_variables()
s2 = SimpleModel()

The error looks like this:

 File "D:\MyProgram\Install\Anaconda3\envs\tensorflow121\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 682, in _get_single_variable
    "VarScope?" % name)
ValueError: Variable embeddings/Adam_2/ does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?

tensorflow version

1.2.1
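Applying the same wrapping-scope idea from above to this snippet would look roughly like the following sketch (one optimizer, minimize called outside the reusing scope; b is changed to tf.get_variable here so it can be shared as well):

import numpy as np
import tensorflow as tf

x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

def calc_loss():
    W = tf.get_variable("w", [1])
    b = tf.get_variable("b", [1], initializer=tf.zeros_initializer())
    y = W * x_data + b
    return tf.reduce_mean(tf.square(y - y_data))

# Keep the reuse flag confined to the model-building block...
with tf.variable_scope(tf.get_variable_scope()):
    loss1 = calc_loss()
    tf.get_variable_scope().reuse_variables()
    loss2 = calc_loss()          # reuses w and b

# ...so that Adam can create its slot variables out here (reuse=False).
train = tf.train.AdamOptimizer(0.5).minimize(loss1 + loss2)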

@Huayra007

Huayra007 commented Aug 10, 2017

@lukaszkaiser Hi, I've run into this problem when using AdamOptimizer. I've tried your suggestion but it still doesn't work. Could you please help me change the code?
Besides, I'm not very familiar with TF, so I'd also appreciate it if you could point out anything not appropriate in this code. Thank you!

import tensorflow as tf
import numpy as np
import datetime
import matplotlib.pyplot as plt
import os
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/')
os.environ['CUDA_VISIBLE_DEVICES']='4'

sample_image = mnist.train.next_batch(1)[0]
print (sample_image.shape)

sample_image = sample_image.reshape([28,28])
plt.imshow(sample_image,cmap='Greys')

def discriminator(images,reuse=False,):
    if(reuse):
        tf.get_variable_scope().reuse_variables()

    with tf.variable_scope('D_conv1'):
        d_w1 = tf.get_variable('d_w1',[5,5,1,32],initializer=tf.truncated_normal_initializer(stddev=0.02))
        d_b1 = tf.get_variable('d_b1',[32],initializer=tf.constant_initializer(0))
        d1 = tf.nn.conv2d(input=images,filter=d_w1,strides=[1,1,1,1],padding='SAME')+d_b1
        d1 = tf.nn.relu(d1)
        d1 = tf.nn.avg_pool(d1,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')

    with tf.variable_scope('D_conv2'):
        d_w2 = tf.get_variable('d_w2',[5,5,32,64],initializer=tf.truncated_normal_initializer(stddev=0.02))
        d_b2 = tf.get_variable('d_b2',[64],initializer=tf.constant_initializer(0))
        d2 = tf.nn.conv2d(input=d1,filter=d_w2,strides=[1,1,1,1],padding='SAME')+d_b2
        d2 = tf.nn.relu(d2)
        d2 = tf.nn.avg_pool(d2,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')

    with tf.variable_scope('D_fcn3'):
        d_w3 = tf.get_variable('d_w3',[7*7*64,1024],initializer=tf.truncated_normal_initializer(stddev=0.02))
        d_b3 = tf.get_variable('d_b3',[1024],initializer=tf.constant_initializer(0))
        d3 = tf.matmul(tf.reshape(d2,[-1,7*7*64]),d_w3)+d_b3
        d3 = tf.nn.relu(d3)

    with tf.variable_scope('D_fc4'):
        d_w4 = tf.get_variable('d_w4',[1024,1],initializer=tf.truncated_normal_initializer(stddev=0.02))
        d_b4 = tf.get_variable('d_b4',[1],initializer=tf.constant_initializer(0))
        d4 = tf.matmul(d3,d_w4)+d_b4
        d4 = tf.nn.sigmoid(d4,name='d4')

    return d4

def generator(z,batch_size,z_dim):
    g_w1 = tf.get_variable('g_w1',[z_dim,56*56],dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.02))
    g_b1 = tf.get_variable('g_b1',[56*56],dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.02))
    g1 = tf.matmul(z,g_w1)+g_b1
    g1 = tf.reshape(g1,[-1,56,56,1])
    g1 = tf.contrib.layers.batch_norm(g1,epsilon=1e-5,scope='bn1')
    g1 = tf.nn.relu(g1)

    g_w2 = tf.get_variable('g_w2',[3,3,1,z_dim/2],dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.02))
    g_b2 = tf.get_variable('g_b2',[z_dim/2],dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.02))
    g2 = tf.nn.conv2d(g1,filter=g_w2,strides=[1,1,1,1],padding='SAME')+g_b2
    g2 = tf.contrib.layers.batch_norm(g2,epsilon=1e-5,scope='bn2')
    g2 = tf.nn.relu(g2)
    g2 = tf.image.resize_images(g2,[56,56])

    g_w3 = tf.get_variable('g_w3',[3,3,z_dim/2,z_dim/4],dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.02))
    g_b3 = tf.get_variable('g_b3',[z_dim/4],dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.02))
    g3 = tf.nn.conv2d(g2,filter=g_w3,strides=[1,1,1,1],padding='SAME')+g_b3
    g3 = tf.contrib.layers.batch_norm(g3,epsilon=1e-5,scope='bn3')
    g3 = tf.nn.relu(g3)
    g3 = tf.image.resize_images(g3,[56,56])

    g_w4 = tf.get_variable('g_w4',[1,1,z_dim/4,1],dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.02))
    g_b4 = tf.get_variable('g_b4',[1],dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.02))
    g4 = tf.nn.conv2d(g3,filter=g_w4,strides=[1,2,2,1],padding='SAME')+g_b4
    g4 = tf.nn.sigmoid(g4)

    return g4

tf.reset_default_graph()
batch_size =100
z_dimension = 100

z_placeholder = tf.placeholder(tf.float32,[None,z_dimension],name='z_placeholder')
x_placeholder = tf.placeholder(tf.float32,[None,28,28,1],name='x_placeholder')

with tf.variable_scope(tf.get_variable_scope()):
    Gz = generator(z_placeholder,batch_size,z_dimension)
    Dx = discriminator(x_placeholder)
    Dg = discriminator(Gz,reuse=True)

d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(Dx),logits=Dx))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(Dg),logits=Dg))
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(Dg),logits=Dg))

tvars = tf.trainable_variables()
d_vars = [var for var in tvars if 'd_' in var.name]
g_vars = [var for var in tvars if 'g_' in var.name]

print ([v.name for v in d_vars])
print ([v.name for v in g_vars])
'''
d_trainer_real = tf.train.GradientDescentOptimizer(0.0003).minimize(d_loss_real,var_list=d_vars)
d_trainer_fake = tf.train.GradientDescentOptimizer(0.003).minimize(d_loss_fake,var_list=d_vars)
g_trainer = tf.train.GradientDescentOptimizer(0.001).minimize(g_loss,var_list=g_vars)
'''
with tf.variable_scope(tf.get_variable_scope()):
    d_trainer_real = tf.train.AdamOptimizer(0.0003).minimize(d_loss_real, var_list=d_vars)
    tf.get_variable_scope().reuse_variables()
    d_trainer_fake = tf.train.AdamOptimizer(0.0003).minimize(d_loss_fake, var_list=d_vars)
    g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)

tf.summary.scalar('Discriminator_loss_real',d_loss_real)
tf.summary.scalar('Discriminator_loss_fake',d_loss_fake)
tf.summary.scalar('Generator_loss',g_loss)

images_for_tensorboard = generator(z_placeholder,batch_size,z_dimension)
tf.summary.image('Generated_images',images_for_tensorboard,5)
merged = tf.summary.merge_all()
logdir = 'Tensorboard/' + datetime.datetime.now().strftime('%Y%m%d-%H%M%S') + '/'
writer = tf.summary.FileWriter(logdir,sess.graph)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Pre-train discriminator
for i in range(3000):
    z_batch = np.random.normal(0, 1, size=[batch_size,z_dimension])
    real_image_batch = mnist.train.next_batch(batch_size)[0].reshape([batch_size,28,28,1])
    _,__,dLossReal,dLossFake = sess.run([d_trainer_real,d_trainer_fake,d_loss_real,d_loss_fake],
                                       feed_dict={x_placeholder:real_image_batch,z_placeholder:z_batch})
    if(i%100==0):
        print ('dLossReal: ',dLossReal,'dLossFake: ',dLossFake)

#Train discriminator and generator together
for i in range(100000):
    real_image_batch = mnist.train.next_batch(batch_size)[0].reshape([batch_size,28,28,1])
    z_batch = np.random.normal(0,1,size=[batch_size,z_dimension])

    _,__,dLossReal,dLossFake = sess.run([d_trainer_real,d_trainer_fake,d_loss_real,d_loss_fake],
                                        feed_dict={x_placeholder:real_image_batch,z_placeholder:z_batch})
    z_batch = np.random.normal(0,1,size=[batch_size,z_dimension])
    _ = sess.run(g_trainer,feed_dict={z_placeholder:z_batch})

    if i%10 ==0:
        z_batch = np.random.normal(0,1,size=[batch_size,z_dimension])
        summary = sess.run(merged,feed_dict={x_placeholder:real_image_batch,z_placeholder:z_batch})
        writer.add_summary(summary,i)

    if i%100==0:
        print ('Iteration:',i,'at',datetime.datetime.now())
        z_batch = np.random.normal(0,1,size=[batch_size,z_dimension])
        generated_image = generator(z_placeholder,1,z_dimension)
        images = sess.run(generated_image,feed_dict={z_placeholder:z_batch})
        plt.imshow(images[0].reshape([28,28]),cmap='Greys')
        plt.savefig('Generated_images/'+str(i)+'.jpg')

        img = images[0].reshape([1,28,28,1])
        result = discriminator(x_placeholder)
        estimate = sess.run(result,feed_dict={x_placeholder:img})
        print ('Estimate:',estimate)

@ivanjacobs

@Huayra007
If you remove the following:

images_for_tensorboard = generator(z_placeholder,batch_size,z_dimension)
tf.summary.image('Generated_images',images_for_tensorboard,5)

then you should be able to run it. You are calling your generator twice, so either remove that snippet or add reuse to your generator, like this:


def generator(z,batch_size,z_dim,reuse=False):
    if (reuse):
        tf.get_variable_scope().reuse_variables()

    g_w1 = tf.get_variable('g_w1',[z_dim,56*56],dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.02))
    g_b1 = tf.get_variable('g_b1',[56*56],dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.02))
    # ... rest of the generator unchanged ...



and when you call it for TensorBoard, like this:

with tf.variable_scope(tf.get_variable_scope()) as scope:
    images_for_tensorboard = generator(z_placeholder, batch_size, z_dimension, reuse=True)
    tf.summary.image('Generated_images', images_for_tensorboard, 5)

I hope this helps. Good luck with your GANs ;-)

@haojitianya

I am a student and not very familiar with TensorFlow. I just followed @lukaszkaiser
and used 'with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope:'
and deleted the 'tf.get_variable_scope().reuse_variables()', and my code works.
I am running the code of ROLO.
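A minimal sketch of that pattern, assuming TF 1.4+ (where tf.AUTO_REUSE was added): variables are created on the first call and reused on later calls, and the optimizer is still invoked outside any reusing scope.

import tensorflow as tf

def tower(x):
    # AUTO_REUSE: create "model/w" on the first call, reuse it afterwards,
    # with no explicit reuse_variables() needed.
    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        w = tf.get_variable("w", shape=[3, 3], initializer=tf.ones_initializer())
        return tf.matmul(x, w)

x = tf.placeholder(tf.float32, [None, 3])
y0 = tower(x)   # creates model/w
y1 = tower(x)   # reuses model/w

# The optimizer runs at the root scope (reuse=False), so its Adam slot
# variables can be created without the "does not exist" error.
train_op = tf.train.AdamOptimizer(1e-3).minimize(tf.reduce_mean(y0 + y1))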

@farzinh

farzinh commented Sep 22, 2018

I'm new to TF. I tried to use:
with tf.variable_scope(tf.get_variable_scope()) as scope:
but it didn't work, and I changed it back to the default. Can you help me change it in the right way?
I attached the python code:
exprgan.py.gz
These 4 methods have tf.get_variable_scope().reuse_variables():
encoder, generator, discriminator_z and discriminator_img

@zaiedsarra

@lukaszkaiser Hello, I need your help :( . In my code I have the same error: ValueError: Variable G_fc/w does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?
But I didn't understand how I can change the code (it works with GradientDescent).

This is the function:

def generator(self, z, y, gender=None, reuse_variables=False, enable_tile_label=True, tile_ratio=1.0):
    if reuse_variables:
        tf.get_variable_scope().reuse_variables()
    num_layers = int(np.log2(self.size_image)) - int(self.size_kernel / 2)
    if enable_tile_label:
        duplicate = int(self.num_z_channels * tile_ratio / self.y_dim)
    else:
        duplicate = 1
    z = concat_label(z, y, duplicate=duplicate)
    if enable_tile_label:
        duplicate = int(self.num_z_channels * tile_ratio / 2)
    else:
        duplicate = 1

    size_mini_map = int(self.size_image / 2 ** num_layers)

    name = 'G_fc'
    current = fc(
        input_vector=z,
        num_output_length=self.num_gen_channels * size_mini_map * size_mini_map,
        name=name,
         # reuse = reuse_variables
    )

    current = tf.reshape(current, [-1, size_mini_map, size_mini_map, self.num_gen_channels])
    current = tf.nn.relu(current)
    current = concat_label(current, y)

    for i in range(num_layers):
        name = 'G_deconv' + str(i)
        current = tf.image.resize_nearest_neighbor(current, [size_mini_map * 2 ** (i + 1), size_mini_map * 2 ** (i + 1)])
        current = custom_conv2d(input_map=current, num_output_channels=int(self.num_gen_channels / 2 ** (i + 1)), name=name)
        current = tf.nn.relu(current)
        current = concat_label(current, y)

    name = 'G_deconv' + str(i + 1)
    current = tf.image.resize_nearest_neighbor(current, [self.size_image, self.size_image])
    current = custom_conv2d(input_map=current, num_output_channels=int(self.num_gen_channels / 2 ** (i + 2)), name=name)
    current = tf.nn.relu(current)
    current = concat_label(current, y)

    name = 'G_deconv' + str(i + 2)
    current = custom_conv2d(input_map=current, num_output_channels=self.num_input_channels, name=name)

    return tf.nn.tanh(current)

And this is the fc function (used for the G_fc layer):

def fc(input_vector, num_output_length, name='fc'):
    with tf.variable_scope(name):
        stddev = np.sqrt(1.0 / (np.sqrt(input_vector.get_shape()[-1].value * num_output_length)))
        w = tf.get_variable(
            name='w',
            shape=[input_vector.get_shape()[1], num_output_length],
            dtype=tf.float32,
            initializer=tf.random_normal_initializer(stddev=stddev)
        )
        b = tf.get_variable(
            name='b',
            shape=[num_output_length],
            dtype=tf.float32,
            initializer=tf.constant_initializer(0.0)
        )
        return tf.matmul(input_vector, w) + b

@SystemErrorWang

I got the same error in a multi-GPU training script, even when I put the whole model definition inside a with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope: and manually added tf.get_variable_scope().reuse_variables() for each GPU. I would like to know what happened; any information or suggestions are appreciated.

@Traeyee

Traeyee commented Mar 18, 2019

I got the same error in a multi-GPU training script, even when I put the whole model definition inside a with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE) as scope: and manually added tf.get_variable_scope().reuse_variables() for each GPU. I would like to know what happened; any information or suggestions are appreciated.

Just do as @wookayin suggested

@SystemErrorWang

@Traeyee Thanks, I fixed the problem after many attempts at changing tf.name_scope and tf.variable_scope.

MarkDaoust pushed a commit to tensorflow/examples that referenced this issue Mar 17, 2020
Without the new variable_scope, creating apply_gradient_op raises
an error that additional moving average or slot variables could not
be created. This is because of the 'leaky reuse' of variable scope,
so we correct the problem by explicitly introducing a new variable scope.

Related issues: tensorflow/models#901, tensorflow/tensorflow#6220
heilov9 added a commit to heilov9/examples that referenced this issue Aug 7, 2024
Without the new variable_scope, creating apply_gradient_op raises
an error that additional moving average or slot variables could not
be created. This is because of the 'leaky reuse' of variable scope,
so we correct the problem by explicitly introducing a new variable scope.

Related issues: tensorflow/models#901, tensorflow/tensorflow#6220