Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change BN layer to use moving mean/var if frozen #9965

Closed
wants to merge 10 commits into from

Conversation

datumbox
Copy link
Contributor

@datumbox datumbox commented Apr 17, 2018

During fine-tuning, if a Batch Normalization layer is frozen it uses the mini-batch statistics. I believe this is incorrect and it can lead to reduced accuracy especially when we use Transfer learning. A better approach in this case would be to use the values of the moving mean and variance.

Changes on this PR:
In this PR I update the Batch Normalization layer to use the learned statistics if frozen during training. This is achieved by making the trainable flag part of the computational graph and by depending the behavior of the BN not only on the learning_phase but also on the value of the trainable property.

Brief explanation:
Assume we use one of the pre-trained CNNs of Keras and we want to fine-tune it. Unfortunately, we get no guarantees that the mean and variance of our new dataset inside the BN layers will be similar to the ones of the original dataset. As a result, if we fine-tune the top layers, their weights will be adjusted to the mean/variance of the new dataset. Nevertheless, during inference the top layers will receive data which are scaled using the mean/variance of the original dataset. This discrepancy can lead to reduced accuracy.

I understand that this is a significant change that requires thorough review. To faciliate the situation I've documented why making such a change is important and provided detailed comparisons before and after applying the patch on my blog.

EDIT: Since the fix was not merged on master, I maintain unofficial patches available for Keras 2.1.6, Keras 2.2.2 and Keras 2.2.4.

@fchollet
Copy link
Collaborator

fchollet commented Apr 19, 2018

Thanks for the effort.

You are misunderstanding the meaning of the "trainable" property of layers. Historically, it has initially meant "this layer should not be trainable, i.e. the weights of this layer should not be updated during backprop (specifically layer.trainable_weights should be empty)". Then it has been extend to mean "the state of the layer should be frozen during training" (which means that, in addition to the previous definition, layer updates are not run).

What you want is a BN layer in inference mode. There is an argument to control training/inference mode in BN (and other layers): it's the training argument in call (boolean).

What you want is:

x = BatchNormalization()(y, training=False)

For fine-tuning, you could do something like:

# Set up inference-mode base
K.set_learning_phase(0)
inputs = Input(...)
x = layer1(...)(inputs)
x = layer2(...)(x)
...
x = layerN(...)(x)

# Add training-mode layers
K.set_learning_phase(1)
x = layerNp1(...)(x)
x = layerNp2(...)(x)

@fchollet fchollet closed this Apr 19, 2018
@datumbox
Copy link
Contributor Author

datumbox commented Apr 19, 2018

Hi @fchollet,

First of all thanks for taking the time to review and respond. I was aware that this is a significant change in the default behaviour and that there would be debate. :)

I understand that your main concern is around the semantic meaning of the trainable property and how it is being used in this PR. I agree that semantically the training parameter that you proposed is closer to what I do, nevertheless this parameter can't change after the network definition. For instance when you use one of the pre-trained models of keras or when you load a persisted model you have no control over this variable. Would you be open to discuss a solution that would make the training variable changeable after the network definition (or perhaps another property)? If you are open to this, I could update my PR to reflect the agreed behaviour.

Concerning your second recommendation of updating the learning_phase as the network is defined, I see two limitations:

  1. Again this will work only if the network is defined based on code. It will not work for the pretrained models of Keras or when a model is loaded from disk. The latter is quite important; models are trained in multiple rounds usually after restoring them from checkpoints.
  2. After setting the learning_phase(1) in your example, the learning_phase will be static for the remaining of the session. This will overwrite all the nice mechanisms that keras has for switching between phases depending on whether it trains or predicts. Thus if we call fit() with validation data, the model will predict while being in training mode.

I'm not sure if you had a look on the blog post (it is understandably a bit long), but you can see how significant perfomance boost you get by making it possible to set the BN in inference mode. Without this the trainable layers after the BNs adjust their weights based on input that has different scale (comparing to inference). I hope, we can re-open this PR; I'm happy to update it until it satisfies the semantic definitions.

Cheers!

@fchollet
Copy link
Collaborator

Again, there is an existing API that does exactly what you want: the training argument in call. There is no point in having two differently name APIs that do the exact same thing. layer.trainable = False is not what you need, therefore don't use it.

Additionally, your proposed PR adds a computational overhead (which might amount to a ~5% slowdown for a BN-heavy model like InceptionV3) to every single convnet that uses BN, fine-tuning or not. This is a heavy price to pay for supporting an incrementally simpler UX (disputable) for a very specific use case.

For instance when you use one of the pre-trained models of keras or when you load a persisted model you have no control over this variable.

Typically if you want to heavily modify an existing model, rather than merely use it in inference mode, you should have access to the code for the model.

But even if you don't, you can still do your style of fine-tuning in this case:

  • set learning phase to 0
  • load model
  • retrieve features you want to train on
  • set learning phase to 1
  • add new layers on top
  • optionally load weights from initial model layers to corresponding new layers
  • train

@datumbox
Copy link
Contributor Author

@fchollet My main point is that the training argument can't be changed after model definition, so the existing API does not cover this valid case. I don't argue that there are workarounds, but they are hacky/non-elegant and the default behaviour leads to much confusion to users. Interesting what you mention about the 5% slow down, I would love to see the benchmarks; perhaps it can be resolved. Finally something you don't address here is whether this discrepancy in the scaling makes sense (theoretically or otherwise) and whether the accuracy decrease is worth it.

At any case, let's agree we disagree. I do hope though that you will revise your decision on the future, as it happened with the update of the mini-batch statistics on the BN.

@datumbox datumbox mentioned this pull request Apr 19, 2018
@fchollet
Copy link
Collaborator

fchollet commented Apr 19, 2018

I would love to see the benchmarks

This is based on something I've observed in the past for InceptionV3 with static learning phase vs. with dynamic learning phase. Only difference between the two settings is cond ops. Control flow seems pretty expensive, especially on GPU. Your PR adds the exact same number cond ops, so I would expect the same overhead.

@datumbox
Copy link
Contributor Author

datumbox commented Apr 20, 2018

Thanks for the clarifying that you are referring to a different benchmark and not to something you ran on this PR. I can't comment on the results without seeing them but when I ran comparisons on CIFAR10 the time difference was negligible (current branch: 4216 secs vs patched: 4251 secs); both ran on GPUs on the same server. Note that the snippet that I used (and listed on my article) comes from Keras' documentation on how to fine-tune a network.

Admittedly the above measurements are single point estimates but especially the 5 point accuracy increase I report is consistent with what I've been observing for almost a year while applying workarounds (first time I reported this is on #7177). I don't know if the speed is currently your main concern for reopening this but I would say that this is unlikely to affect the majority of the users of Keras. This is because by default the Learning Phase is dynamic and the training argument of call is None. This will force the in_train_phase method on backend to use a switch statement that depends on learning phase, so in a sense the "if" statement is already there.

At any case I don't insist that it should me who changes this or that my current solution is the one we should use. I'm just raising a valid use case that is taken directly from Keras' documentation on how fine-tuning is performed. Currently there is no straightforward way to do what I describe (the current API doesn't cover it), nevertheless if you provide specific guidelines on what tickboxes the update should check it would be useful. Or perhaps some other longtime contributor of the BatchNormalization layer has an opinion or can offer a more elegant solution on this? @ozabluda @taehoonlee @farizrahman4u @Dref360

@ozabluda
Copy link
Contributor

Sorry for late reply, still trying to understand the issues. For example, I am trying to understand if this is related at all to #9214

@ahundt
Copy link
Contributor

ahundt commented Apr 23, 2018

What sort of batch sizes were you using in your linked experiments?

Some datasets are only viable with very small batch sizes of 1-4, like with image segmentation on a GPU with 8GB of memory. After briefly skimming this diff, I think the documentation would need to be updated to clearly delineate the different modes and when/why each should typically be chosen. In my case the current frozen behavior improved performance quite a lot over the previous behavior in which mean/var could shift when trainable=False, so I'm a bit hesitant about this though I'll reiterate I haven't reviewed what's happening in full detail.

Here is a PR with some past discussion on BN #8616

@datumbox
Copy link
Contributor Author

datumbox commented Apr 23, 2018

@ozabluda First of all thank you for spending time on this. I wish I had provided on my PR the example that you posted on the issue #9214; perhaps this would have built a stronger case for this patch. What you showed on your post is exactly what I've been observing on real-world non-opensource datasets for the last year (close to 100% accuracy on training mode and 50% during inference on the same dataset and on similar validation sets). As @fchollet said the are lots of hacks that can help you avoid it but none of them should have been necessary.

Based on the code you provided, I'm 100% certain you are being bitten by the behaviour of the BN layer that I'm trying to fix in this PR. In a nutshell, during training mode the frozen BN layers are scaled with different statistics than in inference mode. There is absolutely no theoretical foundation to support this behaviour. As a result, this can have devastating effects when you try to deploy the model or when you try to validate its accuracy. I am certain that the majority of people who face this believe they have overfitted the model while in reality this is just a side-effect of how Keras implements the Batch Normalization layer.

So let's test your example on my branch of Keras where the BN layer is patched:

pip install -U --force-reinstall --no-dependencies git+https://github.com/datumbox/keras@bugfix/trainable_bn

Below I run your code for ResNet50. As you can see the problem that you report is fixed once the BN behaviour is changed:

Epoch 1/100
50/50 [==============================] - 19s 387ms/step - loss: 0.8738 - acc: 0.5000 - val_loss: 1.3021 - val_acc: 0.5000
Epoch 2/100
50/50 [==============================] - 18s 367ms/step - loss: 1.3021 - acc: 0.5000 - val_loss: 0.9412 - val_acc: 0.5000
Epoch 3/100
50/50 [==============================] - 18s 363ms/step - loss: 0.9412 - acc: 0.5000 - val_loss: 0.6904 - val_acc: 0.5000
Epoch 4/100
50/50 [==============================] - 18s 364ms/step - loss: 0.6904 - acc: 0.5000 - val_loss: 0.9428 - val_acc: 0.5000
Epoch 5/100
50/50 [==============================] - 18s 361ms/step - loss: 0.9428 - acc: 0.5000 - val_loss: 0.9180 - val_acc: 0.5000
Epoch 6/100
50/50 [==============================] - 20s 401ms/step - loss: 0.9180 - acc: 0.5000 - val_loss: 0.7111 - val_acc: 0.5000
Epoch 7/100
50/50 [==============================] - 21s 415ms/step - loss: 0.7111 - acc: 0.5000 - val_loss: 0.6802 - val_acc: 0.5200
Epoch 8/100
50/50 [==============================] - 20s 406ms/step - loss: 0.6802 - acc: 0.5200 - val_loss: 0.8039 - val_acc: 0.5000
Epoch 9/100
50/50 [==============================] - 20s 391ms/step - loss: 0.8039 - acc: 0.5000 - val_loss: 0.8075 - val_acc: 0.5000
Epoch 10/100
50/50 [==============================] - 21s 425ms/step - loss: 0.8075 - acc: 0.5000 - val_loss: 0.6963 - val_acc: 0.5000
Epoch 11/100
50/50 [==============================] - 21s 417ms/step - loss: 0.6963 - acc: 0.5000 - val_loss: 0.6406 - val_acc: 0.7000
Epoch 12/100
50/50 [==============================] - 21s 419ms/step - loss: 0.6406 - acc: 0.7000 - val_loss: 0.7017 - val_acc: 0.5000
Epoch 13/100
50/50 [==============================] - 21s 425ms/step - loss: 0.7017 - acc: 0.5000 - val_loss: 0.7408 - val_acc: 0.5000
Epoch 14/100
50/50 [==============================] - 22s 441ms/step - loss: 0.7408 - acc: 0.5000 - val_loss: 0.6895 - val_acc: 0.5000
Epoch 15/100
50/50 [==============================] - 22s 432ms/step - loss: 0.6895 - acc: 0.5000 - val_loss: 0.6267 - val_acc: 0.7200
Epoch 16/100
50/50 [==============================] - 23s 460ms/step - loss: 0.6267 - acc: 0.7200 - val_loss: 0.6376 - val_acc: 0.5600
Epoch 17/100
50/50 [==============================] - 22s 439ms/step - loss: 0.6376 - acc: 0.5600 - val_loss: 0.6775 - val_acc: 0.5400
Epoch 18/100
50/50 [==============================] - 23s 456ms/step - loss: 0.6775 - acc: 0.5400 - val_loss: 0.6675 - val_acc: 0.5400
Epoch 19/100
50/50 [==============================] - 21s 414ms/step - loss: 0.6675 - acc: 0.5400 - val_loss: 0.6209 - val_acc: 0.6000
Epoch 20/100
50/50 [==============================] - 19s 375ms/step - loss: 0.6209 - acc: 0.6000 - val_loss: 0.6055 - val_acc: 0.7400
Epoch 21/100
50/50 [==============================] - 18s 367ms/step - loss: 0.6055 - acc: 0.7400 - val_loss: 0.6309 - val_acc: 0.5800
Epoch 22/100
50/50 [==============================] - 18s 370ms/step - loss: 0.6309 - acc: 0.5800 - val_loss: 0.6392 - val_acc: 0.5600
Epoch 23/100
50/50 [==============================] - 18s 369ms/step - loss: 0.6392 - acc: 0.5600 - val_loss: 0.6111 - val_acc: 0.6400
Epoch 24/100
50/50 [==============================] - 19s 390ms/step - loss: 0.6111 - acc: 0.6400 - val_loss: 0.5890 - val_acc: 0.7800
Epoch 25/100
50/50 [==============================] - 20s 394ms/step - loss: 0.5890 - acc: 0.7800 - val_loss: 0.5990 - val_acc: 0.6200
Epoch 26/100
50/50 [==============================] - 22s 445ms/step - loss: 0.5990 - acc: 0.6200 - val_loss: 0.6105 - val_acc: 0.5800
Epoch 27/100
50/50 [==============================] - 21s 413ms/step - loss: 0.6105 - acc: 0.5800 - val_loss: 0.5961 - val_acc: 0.6000
Epoch 28/100
50/50 [==============================] - 19s 388ms/step - loss: 0.5961 - acc: 0.6000 - val_loss: 0.5759 - val_acc: 0.8000
Epoch 29/100
50/50 [==============================] - 20s 391ms/step - loss: 0.5759 - acc: 0.8000 - val_loss: 0.5767 - val_acc: 0.7400
Epoch 30/100
50/50 [==============================] - 19s 372ms/step - loss: 0.5767 - acc: 0.7400 - val_loss: 0.5857 - val_acc: 0.7400
Epoch 31/100
50/50 [==============================] - 22s 433ms/step - loss: 0.5857 - acc: 0.7400 - val_loss: 0.5785 - val_acc: 0.7600
Epoch 32/100
50/50 [==============================] - 19s 373ms/step - loss: 0.5785 - acc: 0.7600 - val_loss: 0.5627 - val_acc: 0.7800
Epoch 33/100
50/50 [==============================] - 21s 417ms/step - loss: 0.5627 - acc: 0.7800 - val_loss: 0.5597 - val_acc: 0.7800
Epoch 34/100
50/50 [==============================] - 21s 422ms/step - loss: 0.5597 - acc: 0.7800 - val_loss: 0.5651 - val_acc: 0.7000
Epoch 35/100
50/50 [==============================] - 18s 365ms/step - loss: 0.5651 - acc: 0.7000 - val_loss: 0.5606 - val_acc: 0.7200
Epoch 36/100
50/50 [==============================] - 18s 362ms/step - loss: 0.5606 - acc: 0.7200 - val_loss: 0.5488 - val_acc: 0.8000
Epoch 37/100
50/50 [==============================] - 18s 361ms/step - loss: 0.5488 - acc: 0.8000 - val_loss: 0.5449 - val_acc: 0.7800
Epoch 38/100
50/50 [==============================] - 18s 361ms/step - loss: 0.5449 - acc: 0.7800 - val_loss: 0.5473 - val_acc: 0.8000
Epoch 39/100
50/50 [==============================] - 18s 361ms/step - loss: 0.5473 - acc: 0.8000 - val_loss: 0.5433 - val_acc: 0.8000
Epoch 40/100
50/50 [==============================] - 18s 368ms/step - loss: 0.5433 - acc: 0.8000 - val_loss: 0.5344 - val_acc: 0.8000
Epoch 41/100
50/50 [==============================] - 19s 372ms/step - loss: 0.5344 - acc: 0.8000 - val_loss: 0.5311 - val_acc: 0.8600
Epoch 42/100
50/50 [==============================] - 18s 367ms/step - loss: 0.5311 - acc: 0.8600 - val_loss: 0.5318 - val_acc: 0.7800
Epoch 43/100
50/50 [==============================] - 18s 366ms/step - loss: 0.5318 - acc: 0.7800 - val_loss: 0.5278 - val_acc: 0.7800
Epoch 44/100
50/50 [==============================] - 18s 367ms/step - loss: 0.5278 - acc: 0.7800 - val_loss: 0.5208 - val_acc: 0.8800
Epoch 45/100
50/50 [==============================] - 18s 363ms/step - loss: 0.5208 - acc: 0.8800 - val_loss: 0.5181 - val_acc: 0.8200
Epoch 46/100
50/50 [==============================] - 18s 367ms/step - loss: 0.5181 - acc: 0.8200 - val_loss: 0.5175 - val_acc: 0.8200
Epoch 47/100
50/50 [==============================] - 19s 372ms/step - loss: 0.5175 - acc: 0.8200 - val_loss: 0.5131 - val_acc: 0.8400
Epoch 48/100
50/50 [==============================] - 19s 372ms/step - loss: 0.5131 - acc: 0.8400 - val_loss: 0.5075 - val_acc: 0.8600
Epoch 49/100
50/50 [==============================] - 19s 384ms/step - loss: 0.5075 - acc: 0.8600 - val_loss: 0.5053 - val_acc: 0.9000
Epoch 50/100
50/50 [==============================] - 19s 382ms/step - loss: 0.5053 - acc: 0.9000 - val_loss: 0.5035 - val_acc: 0.8400
Epoch 51/100
50/50 [==============================] - 18s 369ms/step - loss: 0.5035 - acc: 0.8400 - val_loss: 0.4989 - val_acc: 0.9000
Epoch 52/100
50/50 [==============================] - 20s 394ms/step - loss: 0.4989 - acc: 0.9000 - val_loss: 0.4944 - val_acc: 0.8800
Epoch 53/100
50/50 [==============================] - 19s 372ms/step - loss: 0.4944 - acc: 0.8800 - val_loss: 0.4920 - val_acc: 0.8800
Epoch 54/100
50/50 [==============================] - 18s 367ms/step - loss: 0.4920 - acc: 0.8800 - val_loss: 0.4890 - val_acc: 0.8800
Epoch 55/100
50/50 [==============================] - 19s 371ms/step - loss: 0.4890 - acc: 0.8800 - val_loss: 0.4845 - val_acc: 0.9000
Epoch 56/100
50/50 [==============================] - 18s 369ms/step - loss: 0.4845 - acc: 0.9000 - val_loss: 0.4811 - val_acc: 0.8800
Epoch 57/100
50/50 [==============================] - 18s 362ms/step - loss: 0.4811 - acc: 0.8800 - val_loss: 0.4792 - val_acc: 0.9000
Epoch 58/100
50/50 [==============================] - 18s 367ms/step - loss: 0.4792 - acc: 0.9000 - val_loss: 0.4759 - val_acc: 0.9000
Epoch 59/100
50/50 [==============================] - 18s 368ms/step - loss: 0.4759 - acc: 0.9000 - val_loss: 0.4721 - val_acc: 0.8800
Epoch 60/100
50/50 [==============================] - 18s 366ms/step - loss: 0.4721 - acc: 0.8800 - val_loss: 0.4695 - val_acc: 0.9200
Epoch 61/100
50/50 [==============================] - 18s 370ms/step - loss: 0.4695 - acc: 0.9200 - val_loss: 0.4670 - val_acc: 0.9000
Epoch 62/100
50/50 [==============================] - 18s 368ms/step - loss: 0.4670 - acc: 0.9000 - val_loss: 0.4634 - val_acc: 0.9200
Epoch 63/100
50/50 [==============================] - 22s 433ms/step - loss: 0.4634 - acc: 0.9200 - val_loss: 0.4602 - val_acc: 0.9200
Epoch 64/100
50/50 [==============================] - 19s 370ms/step - loss: 0.4602 - acc: 0.9200 - val_loss: 0.4578 - val_acc: 0.9200
Epoch 65/100
50/50 [==============================] - 19s 374ms/step - loss: 0.4578 - acc: 0.9200 - val_loss: 0.4548 - val_acc: 0.9200
Epoch 66/100
50/50 [==============================] - 19s 383ms/step - loss: 0.4548 - acc: 0.9200 - val_loss: 0.4515 - val_acc: 0.9400
Epoch 67/100
50/50 [==============================] - 20s 393ms/step - loss: 0.4515 - acc: 0.9400 - val_loss: 0.4488 - val_acc: 0.9200
Epoch 68/100
50/50 [==============================] - 19s 373ms/step - loss: 0.4488 - acc: 0.9200 - val_loss: 0.4462 - val_acc: 0.9200
Epoch 69/100
50/50 [==============================] - 19s 373ms/step - loss: 0.4462 - acc: 0.9200 - val_loss: 0.4431 - val_acc: 0.9400
Epoch 70/100
50/50 [==============================] - 18s 364ms/step - loss: 0.4431 - acc: 0.9400 - val_loss: 0.4402 - val_acc: 0.9400
Epoch 71/100
50/50 [==============================] - 18s 366ms/step - loss: 0.4402 - acc: 0.9400 - val_loss: 0.4376 - val_acc: 0.9800
Epoch 72/100
50/50 [==============================] - 19s 370ms/step - loss: 0.4376 - acc: 0.9800 - val_loss: 0.4347 - val_acc: 0.9800
Epoch 73/100
50/50 [==============================] - 18s 369ms/step - loss: 0.4347 - acc: 0.9800 - val_loss: 0.4317 - val_acc: 0.9400
Epoch 74/100
50/50 [==============================] - 18s 369ms/step - loss: 0.4317 - acc: 0.9400 - val_loss: 0.4291 - val_acc: 0.9400
Epoch 75/100
50/50 [==============================] - 19s 372ms/step - loss: 0.4291 - acc: 0.9400 - val_loss: 0.4264 - val_acc: 0.9400
Epoch 76/100
50/50 [==============================] - 18s 369ms/step - loss: 0.4264 - acc: 0.9400 - val_loss: 0.4235 - val_acc: 0.9400
Epoch 77/100
50/50 [==============================] - 19s 376ms/step - loss: 0.4235 - acc: 0.9400 - val_loss: 0.4208 - val_acc: 0.9600
Epoch 78/100
50/50 [==============================] - 19s 377ms/step - loss: 0.4208 - acc: 0.9600 - val_loss: 0.4182 - val_acc: 0.9800
Epoch 79/100
50/50 [==============================] - 19s 381ms/step - loss: 0.4182 - acc: 0.9800 - val_loss: 0.4154 - val_acc: 0.9600
Epoch 80/100
50/50 [==============================] - 19s 370ms/step - loss: 0.4154 - acc: 0.9600 - val_loss: 0.4127 - val_acc: 0.9400
Epoch 81/100
50/50 [==============================] - 18s 369ms/step - loss: 0.4127 - acc: 0.9400 - val_loss: 0.4101 - val_acc: 0.9400
Epoch 82/100
50/50 [==============================] - 19s 371ms/step - loss: 0.4101 - acc: 0.9400 - val_loss: 0.4075 - val_acc: 0.9400
Epoch 83/100
50/50 [==============================] - 18s 364ms/step - loss: 0.4075 - acc: 0.9400 - val_loss: 0.4048 - val_acc: 0.9600
Epoch 84/100
50/50 [==============================] - 18s 365ms/step - loss: 0.4048 - acc: 0.9600 - val_loss: 0.4022 - val_acc: 0.9800
Epoch 85/100
50/50 [==============================] - 18s 367ms/step - loss: 0.4022 - acc: 0.9800 - val_loss: 0.3996 - val_acc: 0.9800
Epoch 86/100
50/50 [==============================] - 18s 368ms/step - loss: 0.3996 - acc: 0.9800 - val_loss: 0.3970 - val_acc: 0.9600
Epoch 87/100
50/50 [==============================] - 18s 370ms/step - loss: 0.3970 - acc: 0.9600 - val_loss: 0.3945 - val_acc: 0.9600
Epoch 88/100
50/50 [==============================] - 18s 367ms/step - loss: 0.3945 - acc: 0.9600 - val_loss: 0.3919 - val_acc: 0.9600
Epoch 89/100
50/50 [==============================] - 18s 366ms/step - loss: 0.3919 - acc: 0.9600 - val_loss: 0.3894 - val_acc: 0.9600
Epoch 90/100
50/50 [==============================] - 18s 366ms/step - loss: 0.3894 - acc: 0.9600 - val_loss: 0.3869 - val_acc: 0.9800
Epoch 91/100
50/50 [==============================] - 19s 371ms/step - loss: 0.3869 - acc: 0.9800 - val_loss: 0.3844 - val_acc: 0.9800
Epoch 92/100
50/50 [==============================] - 18s 368ms/step - loss: 0.3844 - acc: 0.9800 - val_loss: 0.3819 - val_acc: 0.9800
Epoch 93/100
50/50 [==============================] - 18s 368ms/step - loss: 0.3819 - acc: 0.9800 - val_loss: 0.3795 - val_acc: 0.9800
Epoch 94/100
50/50 [==============================] - 18s 366ms/step - loss: 0.3795 - acc: 1.0000 - val_loss: 0.3770 - val_acc: 1.0000
Epoch 95/100
50/50 [==============================] - 18s 369ms/step - loss: 0.3770 - acc: 1.0000 - val_loss: 0.3746 - val_acc: 1.0000
Epoch 96/100
50/50 [==============================] - 18s 366ms/step - loss: 0.3746 - acc: 1.0000 - val_loss: 0.3722 - val_acc: 1.0000
Epoch 97/100
50/50 [==============================] - 18s 368ms/step - loss: 0.3722 - acc: 1.0000 - val_loss: 0.3698 - val_acc: 1.0000
Epoch 98/100
50/50 [==============================] - 18s 366ms/step - loss: 0.3698 - acc: 1.0000 - val_loss: 0.3674 - val_acc: 1.0000
Epoch 99/100
50/50 [==============================] - 18s 367ms/step - loss: 0.3674 - acc: 1.0000 - val_loss: 0.3651 - val_acc: 1.0000
Epoch 100/100
50/50 [==============================] - 18s 368ms/step - loss: 0.3651 - acc: 1.0000 - val_loss: 0.3627 - val_acc: 1.0000

I would love to know if you can reproduce my results and whether you can observe any speed degradation that @fchollet suspects.

@datumbox
Copy link
Contributor Author

@ahundt Thanks for your comment!

In this very specific experiment I used a fixed batch size of 32. Nevertheless in this dummy example I try to reproduce a behaviour we've been facing for over a year now on real-world datasets and problems. In those cases a large number of different batch sizes were tested and the results were comparable.

Please note that his PR DOES NOT undo the recent change where the mean/var no longer shifts when trainable=False. I 100% agree with you that this change is very beneficial. This PR actually takes it a step further and makes sure that the moving mean/var are used instead of the mini-batch statistics when trainable=False. This ensures that the non-frozen layers will be trained on data scaled the same way as in inference mode.

BTW thanks for sending me the discussion on #8616. Give me sometime to read all the details to see how this is related.

@datumbox
Copy link
Contributor Author

@ahundt I've read the discussion on #8616. I understand it focuses on the previous change on BN that correctly stopped the update of the moving mean/var when trainable=False. I totally agree with this change. As I said on my previous comment, this PR takes this a step further to ensure that the data after a frozen BN are scaled in the same way during training as during inference.

What I find interesting is that during the original discussion on #8616, @fchollet raises similar concerns about the semantic meaning of trainable as in this PR. Nevertheless in that discussion, he proposed the introduction of another property to extend the API. I also see he tried to implement another property called "updatable" which was reverted due to the increased complexity (and at the end we settled with extending the semantics of trainable). I wonder if in this case it makes sense to extend the API to cover this valid case OR update the semantics of trainable (preferred solution) OR update the documentation/examples.

I would love to have an opinion from @lukedeo on this since he reviewed the code on the other PR.

@ahundt
Copy link
Contributor

ahundt commented May 1, 2018

@datumbox Ok think I see what you are saying, I might try this out on my dataset. Do I need to change any settings like trainable in my training code or can I just pull this in? In my example I use frozen vgg16 imagenet pretrained weights as a feature extractor with additional trainable layers afterwards.

One thing that might help with getting this through is a few improvements to the PR, variable names, and description. If you better separate the concepts and clarify the conditions under which different data is fixed vs changing the reasons this improves performance may be more obvious.

@ahundt
Copy link
Contributor

ahundt commented May 1, 2018

Ok so based on the test_batchnorm_trainable() changes this should be active by default in all cases except when both learning phase=1 and trainable=True.

# In all other cases we should use the moving mean and variance from BN.

Correct?

@datumbox
Copy link
Contributor Author

datumbox commented May 1, 2018

@ahundt Thanks for looking into this. My PR affects only networks that use Batch Normalization layers, so VGG will not be affected. No additional configuration is required other than setting trainable=False on the BN layers. Pulling this in should work fine just note that my fork is not synced with the latest stable version of Keras; I plan to do this soon as other people are interested. I synced the patch on Keras 2.1.6, Keras 2.2.2 and Keras 2.2.4.

One thing that might help with getting this through is a few improvements to the PR, variable names, and description.

Sure thing, send me your comments and I'll make the changes. :-)

@ahundt
Copy link
Contributor

ahundt commented May 3, 2018

Oh, yeah sorry I did a first run definitely wasn't configured correctly since vgg makes no sense for this case, and the BN layers I had were trained from scratch. I did have other models including resnet and densenet that didn't perform as well as vgg that use pretrained weights, and the fix in this PR might be why. I will try them out but can you confirm the following steps will make use of your changes?

  1. load pretrained weights for densenet (trainable = false)
  2. add some additional layers on the end
  3. set all layers from (1) including bn to trainable = False, layers from (2) to trainable = True
  4. run training script

Should I expect the above sequence to change the performance when running with this PR applied?

edit: fixed typo mentioned in the next post

@captainst
Copy link

captainst commented Oct 8, 2019

After some try out, I think that a fesible yet simple solution is this (pseudo code, taking inceptionV3 as example):

1. K.set_learning_phase(0) # test mode, freezing everything
2. myModel = InceptionV3(weights='imagenet', include_top=False, input_shape=(299,299,3)) # load pre-trained model and weight
3. nn_inputs = myModel.input # save input for layer use
4. for layer in myModel.layers:
        layer.trainable = False # freeze the weights in each layer. Notice that BNs are also freezed
5. K.set_learning_phase(1) # switch to training mode
6. # build the top layers 
    myModelOut = myModel.output
    myModelOut = GlobalAveragePooling2D()(myModelOut)
    myModelOut = Dense(1024, activation="relu")(myModelOut)
    myModelOut = Dense(10, activation="softmax")(myModelOut)
7. # build the whole model
    finalModel = Model(inputs=nn_inputs, outputs=myModelOut)
8. # verify the model structure and parameters
    print(model.summary())

@gjy1992
Copy link

gjy1992 commented Oct 21, 2019

@captainst Hello, this way can work. But after I save the finalModel, and want to continue a training through load the saved_model, I cannot set part of the finalModel be at learning_phase=0. (╥╯^╰╥)

@ec1841
Copy link

ec1841 commented Nov 2, 2019

@captainst your approach doesn't work, when you want to fine-tune top-k layers part of your base-model that may have BN layers.

Yet another work-around :), using @faustomorales suggestion as the base soup to solve the top-k layer fine-tuning.


from keras import layers

class FrozenBatchNormalization(layers.BatchNormalization):
    def call(self, inputs, training=None):
        return super().call(inputs=inputs, training=False)

model = InceptionResNetV2(....)
if mode == 'training':
    _bottom_layers = model.layers[:-top_k_layers]
    _top_layers = model.layers[-top_k_layers:]
elif mode == 'inference':
    _bottom_layers = model.layers
    _top_layers = []

for _layer in _bottom_layers:
    _layer.trainable = False
    if (_is_batch_normalization(_layer)):
        print('Freezing BN layers ... {}'.format(_layer.name))
        _layer = FrozenBatchNormalization

for _layer in _top_layers:
    _layer.trainable = True
    if (_is_batch_normalization(_layer)):
        print('Unfreezing BN layers ... {}'.format(_layer.name))
        _layer = layers.BatchNormalization

Will this work?

@rpeloff
Copy link

rpeloff commented Nov 3, 2019

@lovejing0306

I used Tensorflow 2.0.0-rc0, when I fine-tuning resnet have same problem.

@off99555

In tensorflow2 GPU there is still this problem occurring, I have to use @faustomorales code in order to fix the issue.

For those using TensorFlow 2.0 and trying to fine-tune resnet, inceptionv3, etc., the problem seems to persist due to the injection of tensorflow.python.keras.layers. This references the TF 1.0 behaviour batch normalisation in keras_applications when calling layers.BatchNormalization (for example, in inceptionv3).

Similar to what @faustomorales suggested, I found that simply injecting tf.keras.layers references the TF 2.0 behaviour batch normalisation and fixes this issue (see here for the change in behaviour). When loading models from tf.keras.applications simply add the argument layers=tf.keras.layers. For example:

import tensorflow as tf

pretrained = tf.keras.applications.inception_v3.InceptionV3(
    layers=tf.keras.layers, weights='imagenet')

@Tauranis
Copy link

Tauranis commented Nov 25, 2019

@datumbox,
If one day I meet you, I promise I'll pay you a beer.
You have not idea how this thread saved me. I've spent two weeks struggling with transfer learning without having any clue of why it was going completely wrong. A simple transfer learning, it was non-sense.
Unfortunately, none of the suggested workarounds worked for me. I'm currently using TF 2.0.0.
The only network that you won't have any headaches is VGG once it has not batch norm layers.

For all others, what works in my case is to do transfer-learning on a 2-step process: Extract embeddings first (into tfrecords shards or not, it is up to you) for further classification.

@rpeloff , your workaround worked for me on TF 1.15.0 but not at TF 2.0.0, but thanks anyway.

@sedghi
Copy link

sedghi commented Dec 3, 2019

I can't believe this is not fixed yet
I have spent 1 week to finally find what's going wrong with my code
Thanks @datumbox , I'll buy you a beer too

@sameervk
Copy link

@sedghi, same here. Thanks @datumbox, this was not a problem before when working with custom layers, but when working with Transfer Learning, I just didn't know what on earth was going on until I came across your blog and then this.

@sameervk
Copy link

@Tauranis would you mind elaborating on your 2-step process please? Thanks.

@RomainSabathe
Copy link

RomainSabathe commented Jan 25, 2020

@lovejing0306

I used Tensorflow 2.0.0-rc0, when I fine-tuning resnet have same problem.

@off99555

In tensorflow2 GPU there is still this problem occurring, I have to use @faustomorales code in order to fix the issue.

For those using TensorFlow 2.0 and trying to fine-tune resnet, inceptionv3, etc., the problem seems to persist due to the injection of tensorflow.python.keras.layers. This references the TF 1.0 behaviour batch normalisation in keras_applications when calling layers.BatchNormalization (for example, in inceptionv3).

Similar to what @faustomorales suggested, I found that simply injecting tf.keras.layers references the TF 2.0 behaviour batch normalisation and fixes this issue (see here for the change in behaviour). When loading models from tf.keras.applications simply add the argument layers=tf.keras.layers. For example:

import tensorflow as tf

pretrained = tf.keras.applications.inception_v3.InceptionV3(
    layers=tf.keras.layers, weights='imagenet')

Thank you so much!!! Looks like it solved it for me. That is certainly a strange behaviour though. One would think that they're using TF 2.x components when using the official TF 2.x release. Anyways, thanks for your reply and the explanation.

@sdpenguin
Copy link

sdpenguin commented Feb 18, 2020

To add to the suggestion made by @faustomorales, I found it useful to actually update the call function of the BatchNormalisation layer base class, after creating the layers, rather than creating a subclass. This means that you don't have to modify your original model code for loading. You could use something similar to the following:

import copy
tf.keras.layers.BatchNormalization._original_call_function = copy.deepcopy(tf.keras.layers.BatchNormalization.call) # We potentially need this to be a different object to avoid recursion
def non_trainable_call_function(_self, inputs, training=None):
    return tf.keras.layers.BatchNormalization._original_call_function(_self, inputs=inputs, training=False)
tf.keras.layers.BatchNormalization.call = non_trainable_call_function

This means that when any BatchNormalization layer is called, the moving mean and variance are calculated, rather than the minibatch mean and variance.

Alternately, you could recurse through only your frozen model and individually set the call functions of the layer if it is an instance of BatchNormalization, rather than changing the global behaviour.

This is a problem that needs to be addressed IMO since training with parts of the model frozen that include BatchNorm is a very difficult thing to spot as a problem and can lead to unnecessary wastage of time trying to identify the problem.

@was84san
Copy link

was84san commented Feb 20, 2020

Can anyone simply answer me about If I am using pre-trained model Densenet121 on Imagenet dataset, and this model has BN layers and I am using it to train facial emotions dataset but I am not freezing any layer and I am usingtf.kerasinstead of standalone Keras, should I be affected by the batch normalization behavior?

@danelee2601
Copy link

danelee2601 commented Feb 27, 2020

I suggest one solution by solving the fundamental problem.

The problem

: keras BN layer uses average and variance of a mini-batch while training even when it's frozen(trainable=False), but we want it to use trained moving_average and moving_variance while training when frozen.

Summary of the solution

  1. Copy a trained model
  2. While copying, BN layers are replaced by Lambda layer that normalizes the input from the previous layer in the same way as the BN layer does. (Lambda layer uses moving_average, moving_variance, gamma, beta from the trained model)
  3. set trainable=False

Validation code

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Dense, BatchNormalization, Lambda

# get base_model
base_model = tf.keras.Sequential()
base_model.add(Dense(8, activation='relu', input_shape=(1,)))
base_model.add(BatchNormalization())
base_model.add(Dense(1))
base_model.compile(loss="mse")

# get simple dataset
X = np.arange(0, 1, 0.1).reshape(-1, 1)
Y = np.copy(X)

# fit base_model
base_model.fit(X, Y, epochs=500)

# make frozen_model for validation
# for BatchNormalization layer, we reconstruct the layer with moving_average, moving_variance, gamma, beta from base_model
# the normalization equation of the BN layer is:: gamma*[(z-moving_mean)/sqrt(moving_variance + epsilon)] + beta
frozen_model = tf.keras.Sequential()
frozen_model.add(base_model.layers[0])
frozen_model.add(Lambda(lambda z: base_model.layers[1].gamma*((z-base_model.layers[1].moving_mean)/K.sqrt(base_model.layers[1].moving_variance + base_model.layers[1].epsilon)) + base_model.layers[1].beta))
frozen_model.add(base_model.layers[2])

frozen_model.trainable = False

# verify if the output from the two models are the same
yhat = base_model.predict(X).flatten()
yhat2 = frozen_model.predict(X).flatten()
print(f"""
yhat: {yhat}\n
yhat2: {yhat2}
""")

# make tf_model for transfer learning
tf_model = tf.keras.Sequential()
tf_model.add(base_model.layers[0])
tf_model.add(Lambda(lambda z: base_model.layers[1].gamma*((z-base_model.layers[1].moving_mean)/K.sqrt(base_model.layers[1].moving_variance + base_model.layers[1].epsilon)) + base_model.layers[1].beta))

tf_model.trainable = False  # freeze the existing layers

tf_model.add(Dense(4, activation="relu"))  # add new layer
tf_model.add(Dense(1))  # add new layer

tf_model.compile(loss="mse")

tf_model.fit(X, Y, epochs=10)

@OverLordGoldDragon
Copy link

OverLordGoldDragon commented Feb 29, 2020

Contrary to most, I agree with @fchollet; the existing API can fulfill this PR's intent. The PR does ease the process, at expense of increased computing time per adding an iteration-level conditional to the graph - but it's a valid patch that could've been merged with a printed warning.

The solution is rather simple; use a model-building function with an argument, e.g. make_model(bn_training) that's used as BatchNormalization()(x, training=bn_training). That this "can't be changed after model compilation" isn't a serious limitation, as models can be easily saved/restored uncompiled, and you're likely loading models for transfer learning anyway.

@leokwu
Copy link

leokwu commented Mar 3, 2020

@datumbox Solved part of my puzzle, But there is another way:
The purpose is to adjust and calculate BN parameters with a set of data set. In the first method, the following layers can be fine-tuned and trained first, and all layers can be trained after the parameters of the pre-trained model are available,it works for me.

Anyway tks. @datumbox @fchollet ,

@rafaspadilha
Copy link

For some reason, @datumbox 's patch did not solve my problem with BN layer. I am using tf 2.0.0 and Keras 2.3.1 (in an anaconda environment). I made sure to correctly alter the normalization.py and tensorflow_backend.py. I have run the testing code from @datumbox blog, but the different behavior between training and testing still remains.

I also tried @rpeloff and @faustomorales solution without success on @datumbox testing script.

@farukgogh
Copy link

Hi @rpeloff
did you solve problem just like you mentioned (pretrained = tf.keras.applications.inception_v3.InceptionV3( layers=tf.keras.layers, weights='imagenet'))
) without any patch?

or after set up @datumbox's patch?

if so which patch is suitable? I am using TF2 and keras 2.3.1

@mathandy
Copy link

Regarding tf.keras, there's now a clear explanation of the relevant behavior in the docs:
https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization#output_shape_2

... This behavior has been introduced in TensorFlow 2.0, in order to enable layer.trainable = False to produce the most commonly expected behavior in the convnet fine-tuning use case.

@datumbox
Copy link
Contributor Author

@mathandy Thanks for the reference. I think it makes sense that TF 2 changed the behavior of the layer to facilitate transfer learning.

Perhaps that's an opportunity to resubmit this fix for consideration. I filed a new PR at #13892, let's see if it will be accepted. :)

@rptrevin
Copy link

rptrevin commented Mar 15, 2020

@datumbox Thank you so much! As you mentioned, this has a severe impact when implementing transfer learning and fine tuning part of the model. I was getting high accuracy during training while inference accuracy was dipping by around 30%. I spent a significant amount of time ensuring that math and code were correct. I implemented your fix and inference started giving F1 and accuracy on par with training. This saved me a significant amount of time and head ache. Thank you again!

@raghavab1992
Copy link

@lovejing0306

I used Tensorflow 2.0.0-rc0, when I fine-tuning resnet have same problem.

@off99555

In tensorflow2 GPU there is still this problem occurring, I have to use @faustomorales code in order to fix the issue.

For those using TensorFlow 2.0 and trying to fine-tune resnet, inceptionv3, etc., the problem seems to persist due to the injection of tensorflow.python.keras.layers. This references the TF 1.0 behaviour batch normalisation in keras_applications when calling layers.BatchNormalization (for example, in inceptionv3).

Similar to what @faustomorales suggested, I found that simply injecting tf.keras.layers references the TF 2.0 behaviour batch normalisation and fixes this issue (see here for the change in behaviour). When loading models from tf.keras.applications simply add the argument layers=tf.keras.layers. For example:

import tensorflow as tf

pretrained = tf.keras.applications.inception_v3.InceptionV3(
    layers=tf.keras.layers, weights='imagenet')

This worked perfectly for me. Thanks everyone for their support

@CMCDragonkai
Copy link

CMCDragonkai commented Jul 14, 2020

If I'm still on TF 1.15 (the last 1.x) release and using tf.keras and Keras-Applications. What is the most succinct solution for this problem? Is it?

import tensorflow as tf

pretrained = tf.keras.applications.inception_v3.InceptionV3(
    layers=tf.keras.layers, weights='imagenet')

Or is it... #9965 (comment)

@Shahidul1004
Copy link

everybody is mentioning that @faustomorales's solution works but i can not see that solution.

@TheGuywithTheHat
Copy link

@Shahidul1004 see #9965 (comment)

francescomilano172 added a commit to ethz-asl/background_foreground_segmentation that referenced this pull request Feb 7, 2021
NOTE: it is particularly important that this command also sets the batch-normalization layers to non-trainable, which now seems to be the standard with Tensorflow 2 + Keras, but is not yet handled well by, e.g., the models from `segmentation_models`. Cf. `freeze_model` from `segmentation_models/models/_utils.py` and, e.g., https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute and keras-team/keras#9965.
@Thunder003
Copy link

Hi, @datumbox, thanks for your contribution. Can you tell me what changes you made on the code side? At least some top-overview(of course apart from what you have mentioned on the blog).
Also, you have written in the blog "Thankfully starting from version 2.1.3, when a BN layer is frozen it no longer updates its statistics. But is that enough? Not if you are using Transfer Learning." What I understood from this is that since Keras 2.1.3 BN layer will be locked up when frozen. So, is there no need for the patch you made if we use Keras 2.1.3?

@AndreyStille
Copy link

@lovejing0306

I used Tensorflow 2.0.0-rc0, when I fine-tuning resnet have same problem.

@off99555

In tensorflow2 GPU there is still this problem occurring, I have to use @faustomorales code in order to fix the issue.

For those using TensorFlow 2.0 and trying to fine-tune resnet, inceptionv3, etc., the problem seems to persist due to the injection of tensorflow.python.keras.layers. This references the TF 1.0 behaviour batch normalisation in keras_applications when calling layers.BatchNormalization (for example, in inceptionv3).

Similar to what @faustomorales suggested, I found that simply injecting tf.keras.layers references the TF 2.0 behaviour batch normalisation and fixes this issue (see here for the change in behaviour). When loading models from tf.keras.applications simply add the argument layers=tf.keras.layers. For example:

import tensorflow as tf

pretrained = tf.keras.applications.inception_v3.InceptionV3(
    layers=tf.keras.layers, weights='imagenet')

But I got this.
TypeError: InceptionV3() got an unexpected keyword argument 'layers'

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.