trainable flag does not work for batch normalization layer #4762
Comments
Please reformulate your question in a clearer fashion. |
I think I'm experiencing the same issue. |
Since this seems to be a problem with calling self.add_update(), not using mode=0 could be a temporary workaround. |
You may need to set the learning phase to testing (0), per https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html; e.g. see the sketch below. |
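A minimal sketch of that suggestion, assuming the Keras backend API (this is an illustration, not the snippet originally posted):

from keras import backend as K

# Put the graph in test mode before building the model, so layers that behave
# differently at train and test time (BatchNormalization, Dropout) always use
# their inference behaviour.
K.set_learning_phase(0)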
Actually, what I wanted was to prevent some part of the network from being trained, and that part contained BatchNormalization layers. |
A little bit more of context:

from keras.layers import normalization
from keras.models import Sequential
import numpy as np
model0 = Sequential()
norm_m0 = normalization.BatchNormalization(input_shape=(10,), momentum=0.8)
model0.add(norm_m0)
model0.summary()
model1 = Sequential()
norm_m1 = normalization.BatchNormalization(input_shape=(10,), momentum=0.8)
model1.add(norm_m1)
for layer in model1.layers:
    layer.trainable = False
model1.compile(loss='mse', optimizer='sgd')
print("Shape batch normalization: {}.".format(len(model0.layers[-1].get_weights())))
print("Before training")
print([np.array_equal(w0, w1) for w0, w1 in zip(model0.layers[-1].get_weights(), model1.layers[-1].get_weights())])
X = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
model1.fit(X, X, nb_epoch=4, verbose=0)
print("After training")
print([np.array_equal(w0, w1) for w0, w1 in zip(model0.layers[-1].get_weights(), model1.layers[-1].get_weights())])

Output:

Shape batch normalization: 4.
Before training
[True, True, True, True]
After training
[True, True, False, False]

We instantiate two models again, this time with mode=1:

from keras.layers import normalization
from keras.models import Sequential
import numpy as np
model0 = Sequential()
norm_m0 = normalization.BatchNormalization(mode=1, input_shape=(10,), momentum=0.8)
model0.add(norm_m0)
model0.summary()
model1 = Sequential()
norm_m1 = normalization.BatchNormalization(mode=1, input_shape=(10,), momentum=0.8)
model1.add(norm_m1)
for layer in model1.layers:
    layer.trainable = False
model1.compile(loss='mse', optimizer='sgd')
print("Shape batch normalization: {}.".format(len(model0.layers[-1].get_weights())))
print("Before training")
print([np.array_equal(w0, w1) for w0, w1 in zip(model0.layers[-1].get_weights(), model1.layers[-1].get_weights())])
X = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
model1.fit(X, X, nb_epoch=4, verbose=0)
print("After training")
print([np.array_equal(w0, w1) for w0, w1 in zip(model0.layers[-1].get_weights(), model1.layers[-1].get_weights())])

Output:

Shape batch normalization: 4.
Before training
[True, True, True, True]
After training
[True, True, True, True]

With the default mode the frozen layer's last two weights (the moving mean and variance) still change during training, while with mode=1 nothing changes. |
Same here, I tried to train a DCGAN model and found that freezing the discriminator with trainable=False only works if the discriminator does not contain any batch norm layer. If the discriminator has a batch norm layer, the output of the model on the same data input changes even when discriminator.trainable=False. Hope there can be a fix :D @fchollet |
I am a little confused about freezing BatchNormalization. It has gamma and beta parameters that are initialized with 1s and 0s respectively by default, and they're also trainable by default. How do I freeze these? The API documentation and source code both say that BatchNormalization does not have However, weirdly, if I do pass |
@Tokukawa I am also having this issue. I just tried setting |
@redsphinx In keras 2 |
Passing |
I believe it's still broken. This code uses Keras 2.0.3 and shows the problem:

import keras
from keras.layers.normalization import BatchNormalization
from keras.models import Sequential
import numpy as np
print("Version: ", keras.__version__)
# Basic model
model = Sequential()
model.add(BatchNormalization(input_shape=(2,)))
model.compile(loss='mse', optimizer='adam')
# Print weights and predictions before training.
X = np.random.normal(size=(1, 2))
print("Prediction before training: ", model.predict(X))
print("Weights before training: ", [[list(w) for w in l.get_weights()] for l in model.layers])
# Train on random output, but set all layers to Trainable=False
Y = np.random.normal(size=(1, 2))
for l in model.layers:
    l.trainable = False
model.fit(X, Y, verbose=0, epochs=10)
print("\n\nPrediction after training: ", model.predict(X))
print("Weights after training: ", [[list(w) for w in l.get_weights()] for l in model.layers])

Output:
As this shows, despite setting trainable = False, the weights and the output of the model change after training. If I'm missing something, I'd appreciate any hints. |
That's expected behavior.
Trainable weights do not change. But batchnorm also maintains non-trainable weights, which are updated via layer updates (i.e. not through backprop): the mean and variance vectors. |
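A quick way to see this split, assuming Keras 2.x layer attributes (illustrative sketch, not from the thread):

from keras.layers import BatchNormalization
from keras.models import Sequential

model = Sequential()
model.add(BatchNormalization(input_shape=(10,)))
bn = model.layers[0]

# gamma and beta are trainable; the moving mean and variance are not, but they
# are still updated during fit() through the layer's update ops.
print([w.name for w in bn.trainable_weights])
print([w.name for w in bn.non_trainable_weights])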
@fchollet So what's the correct way to freeze a batch normalization layer? As in, to freeze both the trainable and non-trainable weights? |
If you want to disable weight updates you can simply call the layer with the argument training=False, e.g. as in the sketch below.
You could also disable weight updates by manually setting the layer's attribute |
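A minimal sketch of calling the layer with training=False in the functional API (the exact snippet from the comment above was not preserved; this is a reconstruction):

from keras.layers import Input, BatchNormalization
from keras.models import Model

inputs = Input(shape=(10,))
# Passing training=False pins the layer to inference mode, so the moving mean
# and variance are not updated during fit().
x = BatchNormalization()(inputs, training=False)
model = Model(inputs, x)
model.compile(loss='mse', optimizer='sgd')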
@fchollet Thanks! The unfortunate side-effect of having to pass training=False is that I can't switch a batch norm layer between trainable and untrainable like I do the other layer types. Instead, I have to re-build the model every time I need to change this property. The "trainable" property makes it convenient to do multi-stage training in which I freeze some layers and train, then unfreeze all the layers and fine-tune. All without having to re-build the model. This flexibility and the consistency across the library is what makes Keras so cool. Having to treat batch norm layers differently breaks the consistency. I'd encourage making "trainable" apply to any weights that change during training, whether the change is done through back prop or a running average. |
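A sketch of the multi-stage workflow described above (illustrative only; the tiny model and random data are made up for the example):

from keras.models import Sequential
from keras.layers import Dense, BatchNormalization
import numpy as np

model = Sequential()
model.add(Dense(16, input_shape=(10,)))
model.add(BatchNormalization())
model.add(Dense(10))

X = np.random.normal(size=(100, 10))
Y = np.random.normal(size=(100, 10))

# Stage 1: freeze everything but the last layer and train the head.
for layer in model.layers[:-1]:
    layer.trainable = False
model.compile(loss='mse', optimizer='adam')  # recompile so the flags take effect
model.fit(X, Y, epochs=2, verbose=0)
# Note: as discussed in this thread, the BatchNormalization moving statistics
# are still updated here unless the layer is also called with training=False.

# Stage 2: unfreeze everything and fine-tune.
for layer in model.layers:
    layer.trainable = True
model.compile(loss='mse', optimizer='adam')
model.fit(X, Y, epochs=2, verbose=0)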
If you are doing fine-tuning and so on, then setting |
@fchollet I am not sure if it makes sense to update the weights of the BatchNormalization layer with trainable=False. Maybe I am wrong, but as I understand, It is basically changing the distribution of the data and not allowing the network to adapt to it. Maybe this effect it is not so important while fine-tuning a model as the distribution has been already learnt and won't change too much, but in my particular case, if I add a BatchNormalization layer in the discriminator of a GAN, it prevents any learning at all. |
The only way I found to solve this issue was to set the momentum to 1. This will make sure that the BN moving_mean and moving_variance are not updated. |
The work-around I've been using is to subclass BatchNormalization as such:
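The original snippet was not preserved in this thread; a minimal sketch of such a subclass, assuming Keras 2.x, might look like this:

from keras.layers import BatchNormalization

class StaticBatchNormalization(BatchNormalization):
    def call(self, inputs, training=None):
        # Ignore the learning phase and always run in inference mode, so the
        # moving mean and variance are never updated.
        return super(StaticBatchNormalization, self).call(inputs, training=False)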
And then I use StaticBatchNormalization instead of the standard BN layer for layers that I want to freeze. Basically what this does is pass training=False to the standard BN layer. I use this approach rather than simply passing this parameter to the standard layer because there are situations where I can't pass parameters to the layer, for example, when wrapping it with a TimeDistributed wrapper. But just like @moshebou's solution, this makes it hard to switch layers between being trainable/untrainable dynamically. |
@moshebou I can't reproduce your approach. I compile d with @waleedka I'm sorry if this is a dumb question, but what does your code do? Do you wrap it around a BN layer or do you use it instead? A link to where I could learn what .call does would also be appreciated! I'm new to Python and I'm learning as I go. |
@Pizzafarmer I updated my comment to be clearer about how and why I use this approach. And regarding @moshebou's method, I believe it's the BN's momentum, not the optimizer's momentum. |
Yes, @waleedka is right: the BN's momentum should be set to 1.
This will take care of the BN. |
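A minimal sketch combining momentum=1 with trainable=False, assuming Keras 2.x (illustrative, not from the thread):

from keras.layers import BatchNormalization
from keras.models import Sequential
import numpy as np

model = Sequential()
# momentum=1.0 makes the moving-statistics update a no-op:
#   moving_mean = moving_mean * 1.0 + batch_mean * 0.0
model.add(BatchNormalization(momentum=1.0, input_shape=(10,)))
model.layers[0].trainable = False           # freeze gamma and beta as well
model.compile(loss='mse', optimizer='sgd')  # compile after setting trainable

X = np.random.normal(size=(100, 10))
before = model.layers[0].get_weights()
model.fit(X, X, epochs=2, verbose=0)
after = model.layers[0].get_weights()
print([np.array_equal(b, a) for b, a in zip(before, after)])  # expect all True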
@moshebou @waleedka thank you both so much! I got it to work! Sadly my generator is not learning at all, even after freezing the BN layers, for learning rates between 10^10 and 10^-10, so my generator architecture is probably bad. Do you have any idea where I could find information on how to design generators and how to get them to learn? |
@Pizzafarmer |
Use this Layer |
I work with Keras 1.0.2 with TensorFlow.
I am trying to freeze some layers in the network; it works well for convolutions and FC layers, but not for the batch normalization layer. I print the weights of the layer before and after one epoch and I see changes. Any ideas?