Setting learning_phase to 0 leads to extremely low accuracy #7177
Comments
I did a simple test to check whether BatchNormalization was the culprit, and found no evidence that it was. Specifically:
from keras.models import load_model, Sequential
from keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, BatchNormalization
from keras.datasets import mnist
from keras import backend as K
import keras
import os
# input image dimensions
img_rows, img_cols = 28, 28
num_classes = 10
# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
def get_trained_model(file_path):
    if os.path.exists(file_path):
        return load_model(file_path)
    batch_size = 128
    epochs = 12
    model = Sequential()
    model.add(Conv2D(16, kernel_size=(3, 3),
                     activation='relu',
                     input_shape=input_shape))
    model.add(Conv2D(4, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    # Use batch normalization instead of Dropout to test it
    model.add(BatchNormalization())
    model.add(Flatten())
    model.add(Dense(16, activation='relu'))
    model.add(BatchNormalization())
    model.add(Dense(num_classes, activation='softmax'))
    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.Adadelta(),
                  metrics=['accuracy'])
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=2,
              validation_data=(x_test, y_test))
    model.save(file_path)
    return model
model = get_trained_model('test.h5')
print('Before Load:', model.evaluate(x_test, y_test, verbose=0))
K.clear_session()
K.set_learning_phase(0)
model = load_model('test.h5')
print('After Load - learning_phase=0:', model.evaluate(x_test, y_test, verbose=0))
K.clear_session()
K.set_learning_phase(1)
model = load_model('test.h5')
print('After Load - learning_phase=1:', model.evaluate(x_test, y_test, verbose=0))
With the TensorFlow backend and Keras 2.0.5, I get:
@jorgecarleitao I think it is important to do fine-tuning on the model (freeze some of the layers) in order to reproduce the results. Note that the way I do fine-tuning here comes from the documentation. Could you add a couple more layers to your example and freeze part of the network? Alternatively, you can run my snippet.
I kept digging into this and noticed that when you freeze part of the model, the moving mean and variance of the frozen BatchNormalization layers keep updating their values (beta and gamma do not, though). Perhaps @fchollet can clarify whether this is a bug or intended behaviour. This might actually be responsible for the problem. It is worth noting that if this is indeed a bug, it can heavily affect people who deploy their models in a live environment.
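For anyone who wants to check this quickly, here is a minimal sketch of the observation above, using toy random data and a made-up layer name rather than the actual fine-tuning snippet:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense, BatchNormalization

# Tiny toy model; the layer name 'frozen_bn' is made up for this example.
model = Sequential()
model.add(Dense(8, activation='relu', input_shape=(4,)))
model.add(BatchNormalization(name='frozen_bn'))
model.add(Dense(1, activation='sigmoid'))

# Freeze the BN layer before compiling.
model.get_layer('frozen_bn').trainable = False
model.compile(optimizer='adam', loss='binary_crossentropy')

bn = model.get_layer('frozen_bn')
# get_weights() returns [gamma, beta, moving_mean, moving_variance]
before = [w.copy() for w in bn.get_weights()]

x = np.random.rand(256, 4)
y = np.random.randint(0, 2, size=(256, 1))
model.fit(x, y, epochs=2, batch_size=32, verbose=0)

# gamma/beta stay fixed, but the moving statistics still change.
after = bn.get_weights()
for name, b, a in zip(['gamma', 'beta', 'moving_mean', 'moving_variance'], before, after):
    print(name, 'changed:', not np.allclose(b, a))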
Maybe related to #4762. As explained in my comment there, I think it should be possible to freeze the BatchNormalization parameters when trainable is set to False.
@jorgecarleitao I actually added more layers to your snippet and performed fine-tuning, but I can't make it break. Still uncertain whether this is caused by the BN layer.
OK, I believe I know what the problem is. It is not a bug, but a side effect of the way we estimate the moving averages in BatchNormalization.
The mean and variance of the training data that I use are different from those of the dataset used to train the ResNet50 (the effect is amplified by the fact that I don't subtract the average pixel and flip the channel order, but you can get the same result even if you do). Because the momentum of BatchNormalization has a default value of 0.99, with only 5 epochs of training the layers do not converge quickly enough to the correct values for the moving mean and variance (a small numeric sketch of this update appears after this comment). This is not obvious during training, when the learning_phase is 1, because BN uses the mean/variance of the batch. Nevertheless, when we set learning_phase to 0, the incorrect mean/variance values learned during training significantly affect the accuracy. The reason why @jorgecarleitao's snippet does not reproduce the problem is that he trains the model from scratch rather than using pre-trained weights.
There are two ways to demonstrate that this is the root of the problem:
1. More iterations. On my original snippet, reduce the batch size from 32 to 16 (to perform more updates per epoch) and increase the number of epochs from 5 to 250. This way the moving mean and variance converge to the correct values. Output:
2. Change the momentum of BatchNormalization. Keep the number of iterations fixed but change the momentum of the BatchNormalization layers so that they update the rolling mean and variance more aggressively (not recommended for production models). Note that changing the momentum field after the model has been initialised has no effect (the TensorFlow graph has already been constructed with the original value), so we use a hacky patch to demonstrate the case. On my original snippet, add the following patch between loading the base_model and defining the new layers:
# ....
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)
# PATCH MOMENTUM - START
import json
from keras.models import Model  # needed for Model.from_config, if not already imported
conf = json.loads(base_model.to_json())
for l in conf['config']['layers']:
    if l['class_name'] == 'BatchNormalization':
        l['config']['momentum'] = 0.5
m = Model.from_config(conf['config'])
for l in base_model.layers:
    m.get_layer(l.name).set_weights(l.get_weights())
base_model = m
# PATCH MOMENTUM - END
x = base_model.output
# ....
Output:
Hope this helps others who face similar issues.
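To put some numbers on the convergence argument above, here is a small numeric sketch (made-up values) of the exponential moving average that BatchNormalization maintains, new = momentum * old + (1 - momentum) * batch_stat:

# Made-up values: the moving mean starts at 0.0 (e.g. inherited from pre-trained
# weights) and the true mean of the new data is 5.0.
def ema_after(momentum, n_updates, start=0.0, target=5.0):
    m = start
    for _ in range(n_updates):
        m = momentum * m + (1.0 - momentum) * target
    return m

print(ema_after(0.99, 100))   # ~3.17 -- still far from 5.0 after 100 updates
print(ema_after(0.99, 1000))  # ~5.00 -- many more updates are needed to converge
print(ema_after(0.5, 100))    # ~5.00 -- a smaller momentum converges almost immediately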
Hi @datumbox,
Loading a persisted model and setting learning_phase=0 reduces the accuracy from 100% to 50% in a binary classification problem.
The script below contains a simple example that reproduces the problem on Keras 2.0.5 and TensorFlow 1.2 (Python 2.7, Ubuntu 14.04, Nvidia Quadro K2200 GPU). I use an extremely small dataset and intentionally overfit the model.
Snippet:
Output:
As you can see above, I explicitly save the model, clear the session and load it again. This is important for reproducing the problem. I don't believe there is an issue with the persistence mechanism of Keras, as the weights before and after load() appear to be the same.
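The weight comparison mentioned above can be done along these lines (a sketch assuming the model, load_model and K from the snippet above; the file name is made up):

import numpy as np

w_before = model.get_weights()
model.save('test_persistence.h5')
K.clear_session()
model = load_model('test_persistence.h5')
w_after = model.get_weights()

# True if every weight tensor survives the save/load round-trip unchanged.
print(all(np.allclose(b, a) for b, a in zip(w_before, w_after)))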
Since ResNet50 does not contain any Dropout layers, I believe the problem is caused by the BatchNormalization layers. As far as I can see in the Keras source, during training we use the sample mean/variance of the mini-batch, while during testing we use the rolling mean/variance.
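One way to inspect this difference directly is to feed the learning phase through a backend function (a sketch assuming the model and x_test from the snippet above):

from keras import backend as K

# Build a function that takes the learning phase as an extra input.
get_output = K.function([model.input, K.learning_phase()], [model.output])

test_mode_pred = get_output([x_test, 0])[0]   # learning_phase=0: BN uses the rolling mean/variance
train_mode_pred = get_output([x_test, 1])[0]  # learning_phase=1: BN uses the mini-batch mean/variance

# A large gap between the two outputs points at the BN statistics.
print(abs(test_mode_pred - train_mode_pred).max())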
Any thoughts from Keras contributors? I'm happy to provide more info or investigate further.
@fchollet Could you provide any hint/pointers where to look next?
Check that you are up-to-date with the master branch of Keras. You can update with:
pip install git+git://github.com/fchollet/keras.git --upgrade --no-deps
If running on TensorFlow, check that you are up-to-date with the latest version. The installation instructions can be found here.
If running on Theano, check that you are up-to-date with the master branch of Theano. You can update with:
pip install git+git://github.com/Theano/Theano.git --upgrade --no-deps
Provide a link to a GitHub Gist of a Python script that can reproduce your issue (or just copy the script here if it is short).