trainable flag does not work for batch normalization layer #4762

Closed
nes123 opened this issue Dec 18, 2016 · 26 comments

Comments

@nes123

nes123 commented Dec 18, 2016

I am using Keras 1.0.2 with TensorFlow.
I am trying to freeze some layers in the network. It works well for convolutional and fully connected layers but not for the batch normalization layer: when I print the layer's weights before and after one epoch, I see changes. Any ideas?

@fchollet
Collaborator

Please reformulate your question in a clearer fashion.

@jaekyeom

jaekyeom commented Feb 4, 2017

I think I'm experiencing the same issue.
My Keras is 1.2.1 and Tensorflow is 0.12.1.
I spent some time wondering why some weights are changing even if I set trainable=False for all models before compiling, and then came here.
If I eliminate the BatchNormalization layers, they stay the same.

@jaekyeom

jaekyeom commented Feb 4, 2017

As it seems like a problem with calling self.add_update(), not using mode=0 could be a temporary workaround.

@scott-vsi
Contributor

You may need to set the learning phase to testing (0): e.g.,

from keras import backend as K
K.set_learning_phase(0)  # all new operations will be in test mode from now on

per https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html

@jaekyeom

jaekyeom commented Feb 8, 2017

Actually, what I wanted was to prevent some part of the network from being trained, and the part contained BatchNormalization layers.

@Tokukawa

Tokukawa commented Mar 2, 2017

A little more context:

from keras.layers import normalization
from keras.models import Sequential
import numpy as np

model0 = Sequential()
norm_m0 = normalization.BatchNormalization(input_shape=(10,), momentum=0.8)
model0.add(norm_m0)

model0.summary()

model1 = Sequential()
norm_m1 = normalization.BatchNormalization(input_shape=(10,), momentum=0.8)
model1.add(norm_m1)
for layer in model1.layers:
    layer.trainable = False
model1.compile(loss='mse', optimizer='sgd')

print("Shape batch normalization: {}.".format(len(model0.layers[-1].get_weights())))
print("Before training")
print([np.array_equal(w0, w1) for w0, w1 in zip(model0.layers[-1].get_weights(), model1.layers[-1].get_weights())])

X = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
model1.fit(X, X, nb_epoch=4, verbose=0)

print("After training")
print([np.array_equal(w0, w1) for w0, w1 in zip(model0.layers[-1].get_weights(), model1.layers[-1].get_weights())])

Output:

Shape batch normalization: 4.
Before training
[True, True, True, True]
After training
[True, True, False, False]

We instantiate two models.
Before training, all the weights of the two models are equal. After training the second model with all of its layers frozen, the last two weight arrays differ.
However, if you change the mode from its default of 0, you get the correct behaviour:

from keras.layers import normalization
from keras.models import Sequential
import numpy as np

model0 = Sequential()
norm_m0 = normalization.BatchNormalization(mode=1, input_shape=(10,), momentum=0.8)
model0.add(norm_m0)

model0.summary()

model1 = Sequential()
norm_m1 = normalization.BatchNormalization(mode=1, input_shape=(10,), momentum=0.8)
model1.add(norm_m1)

for layer in model1.layers:
    layer.trainable = False

model1.compile(loss='mse', optimizer='sgd')


print("Shape batch normalization: {}.".format(len(model0.layers[-1].get_weights())))
print("Before training")
print([np.array_equal(w0, w1) for w0, w1 in zip(model0.layers[-1].get_weights(), model1.layers[-1].get_weights())])

X = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
model1.fit(X, X, nb_epoch=4, verbose=0)

print("After training")
print([np.array_equal(w0, w1) for w0, w1 in zip(model0.layers[-1].get_weights(), model1.layers[-1].get_weights())])

Output:

Shape batch normalization: 4.
Before training
[True, True, True, True]
After training
[True, True, True, True]

@thematrixduo

thematrixduo commented Apr 8, 2017

Same here. I tried to train a DCGAN model and found that freezing the discriminator with trainable=False only works if the discriminator does not contain any batch norm layer. If the discriminator has a batch norm layer, the output of the model on the same input data changes even when discriminator.trainable=False. Hope there can be a fix :D @fchollet

@litesaber15

I am a little confused about freezing BatchNormalization. It has gamma and beta parameters that are initialized with 1s and 0s respectively by default, and they're also trainable by default. How to freeze these? The API documentation and source code both say that BatchNormalization does not have trainable as a parameter.

However, weirdly, if I do pass trainable=False in BatchNormalization(), the number of trainable parameters drops to 0. What am I missing here?

@redsphinx

redsphinx commented Apr 19, 2017

@Tokukawa I am also having this issue. I just tried setting mode=1 and it gives the following error: TypeError: The mode argument of BatchNormalization no longer exists. mode=1 and mode=2 are no longer supported.

@Tokukawa

@redsphinx In Keras 2, BatchNormalization has been rewritten from scratch. I haven't tried it yet, but I guess this issue is gone now.

@milsto

milsto commented Apr 20, 2017

Passing trainable=False in BatchNormalization() will freeze the layer parameters in Keras 2 (tested with Keras 2.0.2). It seems that @Tokukawa is right, and I think this issue can be closed.

@waleedka
Contributor

waleedka commented May 6, 2017

I believe it's still broken. This code uses Keras 2.0.3 and shows the problem.

import keras
from keras.layers.normalization import BatchNormalization
from keras.models import Sequential
import numpy as np

print("Version: ", keras.__version__)

# Basic model
model = Sequential()
model.add(BatchNormalization(input_shape=(2,)))
model.compile(loss='mse', optimizer='adam')

# Print weights and predictions before training.
X = np.random.normal(size=(1, 2))
print("Prediction before training: ", model.predict(X))
print("Weights before training: ", [[list(w) for w in l.get_weights()] for l in model.layers])


# Train on random output, but set all layers to Trainable=False
Y = np.random.normal(size=(1, 2))
for l in model.layers:
    l.trainable = False
model.fit(X, Y, verbose=0, epochs=10)

print("\n\nPrediction after training: ", model.predict(X))
print("Weights after training: ", [[list(w) for w in l.get_weights()] for l in model.layers])

Output:

Version:  2.0.3
Prediction before training:  [[ 0.99468702 -0.22452217]]
Weights before training:  [[[1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [1.0, 1.0]]]


Prediction after training:  [[ 0.93589312 -0.2035163 ]]
Weights after training:  [[[1.0, 1.0], [-0.0099943792, 0.0099907704], [0.095157444, -0.021479076], [0.90438205, 0.90438205]]]

As this shows, despite setting trainable = False, the weights and the output of the model are changed after training. If I'm missing something, I'd appreciate the hints.

@fchollet
Collaborator

fchollet commented May 6, 2017

That's expected behavior.

> despite setting trainable = False, the weights and the output of the model are changed after training

Trainable weights do not change. But batchnorm also maintains non-trainable weights, which are updated via layer updates (i.e. not through backprop): the mean and variance vectors.
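The moving variance in @waleedka's output above is consistent with this explanation. A minimal sketch, assuming the Keras 2 default momentum of 0.99 and that each of the 10 epochs performs one update on the single sample (so the batch variance is 0); `update_moving` is a hypothetical helper name, not a Keras function:

```python
def update_moving(moving, batch_stat, momentum=0.99):
    """Exponential moving average applied as a layer update (no backprop)."""
    return momentum * moving + (1.0 - momentum) * batch_stat


moving_var = 1.0  # initial moving variance
for _ in range(10):  # 10 epochs, one update per epoch
    # A batch with a single sample has zero variance.
    moving_var = update_moving(moving_var, 0.0)

print(round(moving_var, 6))  # 0.904382
```

This matches the reported 0.90438205 up to float32 rounding, which suggests the change in the last two weight arrays comes from the moving-average update rather than from the optimizer.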

@fchollet fchollet closed this as completed May 6, 2017
@waleedka
Contributor

waleedka commented May 6, 2017

> That's expected behavior.

@fchollet So what's the correct way to freeze a batch normalization layer? As in, to freeze both the trainable and non-trainable weights?

@fchollet
Collaborator

fchollet commented May 6, 2017

> @fchollet So what's the correct way to freeze a batch normalization layer? As in, to freeze both the trainable and non-trainable weights?

If you want to disable weight updates you can simply call the layer with the argument training=False in the functional API, which disables the Keras learning phase for this layer (i.e. the layer will always run in inference mode, even when training the model).

e.g.

x = BatchNormalization()(x, training=False)

You could also disable weight updates by manually setting the layer's attribute _per_input_updates to {}, but that's not part of the public API.
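To make the difference between the two modes concrete, here is a rough pure-Python sketch of what a `training` flag controls. `TinyBatchNorm` is a hypothetical single-feature stand-in written for illustration, not Keras code, and its defaults (momentum 0.99, epsilon 1e-3) are assumed to mirror the Keras 2 defaults:

```python
class TinyBatchNorm:
    """Hypothetical single-feature batch norm, for illustration only."""

    def __init__(self, momentum=0.99, epsilon=1e-3):
        self.gamma, self.beta = 1.0, 0.0              # trainable (backprop)
        self.moving_mean, self.moving_var = 0.0, 1.0  # non-trainable (layer updates)
        self.momentum, self.epsilon = momentum, epsilon

    def __call__(self, batch, training):
        if training:
            # Training mode: normalize with batch statistics and
            # update the moving averages as a side effect.
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_var = m * self.moving_var + (1 - m) * var
        else:
            # Inference mode: read the moving averages, never write them.
            mean, var = self.moving_mean, self.moving_var
        scale = self.gamma / (var + self.epsilon) ** 0.5
        return [scale * (x - mean) + self.beta for x in batch]


bn = TinyBatchNorm()
bn([0.0, 4.0], training=False)
stats_after_inference = (bn.moving_mean, bn.moving_var)  # still (0.0, 1.0)
bn([0.0, 4.0], training=True)
stats_after_training = (bn.moving_mean, bn.moving_var)   # both have moved
```

In this sketch the inference path is the only one that leaves every stored value untouched, which is the effect the training=False call is meant to achieve.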

@waleedka
Contributor

waleedka commented May 6, 2017

@fchollet Thanks! The unfortunate side-effect of having to pass training=False is that I can't switch a batch norm layer between trainable and untrainable like I do the other layer types. Instead, I have to re-build the model every time I need to change this property.

The "trainable" property makes it convenient to do multi-stage training in which I freeze some layers and train, then unfreeze all the layers and fine-tune. All without having to re-build the model. This flexibility and the consistency across the library is what makes Keras so cool. Having to treat batch norm layers differently breaks the consistency.

I'd encourage making "trainable" apply to any weights that change during training, whether the change is done through back prop or a running average.

@fchollet
Collaborator

fchollet commented May 6, 2017

If you are doing fine-tuning and so on, then setting trainable to False is exactly what you want. The fact that BN will adapt to the statistics of your new data is precisely what you want.

@undo76

undo76 commented Jul 24, 2017

@fchollet I am not sure it makes sense to update the weights of the BatchNormalization layer with trainable=False. Maybe I am wrong, but as I understand it, this basically changes the distribution of the data without allowing the network to adapt to it. Maybe this effect is not so important while fine-tuning a model, as the distribution has already been learnt and won't change too much, but in my particular case, if I add a BatchNormalization layer in the discriminator of a GAN, it prevents any learning at all.

@moshebou

moshebou commented Aug 9, 2017

The only way I found to solve this issue was to set the momentum to 1. This will make sure that the BN moving_mean and moving_variance are not updated.
HOWEVER - you cannot update the momentum on-the-fly, since the TF model is created only when added to the network.
So:
Create two identical networks: net1 with momentum=0.99 and net2 with momentum=1.
Train net1, and get the BN mean and var trained.
Then, when you wish to not update the BN, do:
net2.set_weights(net1.get_weights())

@waleedka
Contributor

waleedka commented Aug 9, 2017

The work-around I've been using is to subclass BatchNormalization as such:

class StaticBatchNormalization(BatchNormalization):
    def call(self, inputs, training=None):
        return super(StaticBatchNormalization, self).call(inputs, training=False)

And then I use StaticBatchNormalization instead of the standard BN layer for layers that I want to freeze. Basically what this does is pass training=False to the standard BN layer. I use this approach rather than simply passing this parameter to the standard layer because there are situations where I can't pass parameters to the layer, for example, when wrapping it with a TimeDistributed wrapper.

But just like @moshebou's solution, this makes it hard to switch layers between trainable and untrainable dynamically.

@ghost

ghost commented Aug 13, 2017

@moshebou I can't reproduce your approach. I compile d with SGD(lr=0, momentum=1, nesterov=False, decay=0) but it still trains. Any idea how to fix this?

@waleedka I'm sorry if this is a dumb question, but what does your code do? Do you wrap it around a BN layer or do you use it instead? A link to where I could learn what .call does would also be appreciated! I'm new to Python and I'm learning as I go.

@waleedka
Contributor

@Pizzafarmer I updated my comment to be clearer about how and why I use this approach. And, regarding @moshebou's method, I believe it's the BN's momentum, not the optimizer's momentum.

@moshebou

Yes, @waleedka is right, the BN's momentum should be set to 1.

...
net1.add(BatchNormalization(axis=-1, momentum=0.99))
...
net2.add(BatchNormalization(axis=-1, momentum=1))
## train net1
net1.fit_generator(...)
## copy weights from net1 to net2
net2.set_weights(net1.get_weights())
## train net2
net2.fit_generator(...)

This will take care of the BN.
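The reason momentum=1 works can be seen from the moving-average formula itself. A small sketch, assuming the update has the usual form `moving = momentum * moving + (1 - momentum) * batch_stat`; `update_moving` is a hypothetical helper, not part of Keras:

```python
def update_moving(moving, batch_stat, momentum):
    """One moving-average update of a batch norm statistic."""
    return momentum * moving + (1.0 - momentum) * batch_stat


# With the usual momentum the stored statistic drifts toward the batch value.
drifting = update_moving(0.25, 3.0, momentum=0.99)

# With momentum=1 the batch term is multiplied by zero, so nothing changes.
frozen = update_moving(0.25, 3.0, momentum=1.0)
```

Note that this only pins the stored moving statistics. As far as the formula shows, the layer still normalizes with the statistics of the current batch while the model is being trained.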

@ghost

ghost commented Aug 14, 2017

@moshebou @waleedka thank you both so much! I got it to work! Sadly, my generator is not learning at all, even after I freeze the BN layers, for learning rates between 10^10 and 10^-10, so my generator architecture is probably bad.

Do you have any idea where I could find information on how to design generators and how to get them to learn?

@moshebou

@Pizzafarmer
I suggest training the generator separately and trying to overfit a small training dataset. Once you see reasonable results from the generator, incorporate it into a GAN architecture.

@gajeshladhar

gajeshladhar commented Aug 21, 2020

Use this Layer


import numpy as np
import tensorflow as tf


class Normalization():
  # Plain Python class for eager TensorFlow 2.x; it is not a keras Layer.
  # Set trainable=False to stop updating the running statistics.
  def __init__(self, eps=1e-3):
    self.alpha=0
    self.beta=0
    self.total_mean=0
    self.total_std=0
    self.eps=eps  # added guard against division by zero (not in the original post)
    self.start=0
    self.trainable=True
  def get_weights(self):
    return [self.alpha,self.beta,self.total_mean,self.total_std]
  def set_weights(self,weights):
    self.alpha,self.beta,self.total_mean,self.total_std=weights

  def __call__(self,X):
    # Work in float32 so inputs, variables and statistics share one dtype.
    X=tf.cast(X,tf.float32)
    if self.start==0 :
      # First call: create the scale/shift variables and the running statistics.
      shape=tuple(X.shape[1:])
      self.alpha=tf.Variable(np.random.random(shape).astype(np.float32))
      self.beta=tf.Variable(np.random.random(shape).astype(np.float32))

      self.total_mean=np.zeros(shape,dtype=np.float32)
      self.total_std=np.zeros(shape,dtype=np.float32)

      self.start=1

    if self.trainable==True :
      # Training: update the running statistics, normalize with batch statistics.
      batch_mean=tf.reduce_mean(X,axis=0)
      batch_std=tf.math.reduce_std(X,axis=0)
      self.total_mean=0.9980*self.total_mean + 0.0020*batch_mean.numpy()
      self.total_std=0.9980*self.total_std + 0.0020*batch_std.numpy()

      X=(X-batch_mean)/(batch_std+self.eps)

    else :
      # Frozen: normalize with the running statistics only.
      X=(X-self.total_mean)/(self.total_std+self.eps)

    return self.alpha*X+self.beta
