BatchNorm numerical instability #3963
I rewrote the BN layer (see https://github.com/drnikolaev/caffe/tree/caffe-0.15)
@borisgin The implementation at https://github.com/borisgin/caffe/tree/caffe-0.15 still has the problem described by @classner.
@borisgin Sorry, my comment above is wrong. The instability problem does NOT exist in your implementation, but there are some tiny mistakes...
@classner I wrote a unit test for BatchNormLayer based on your description. According to the debug output, the computed mean is correct. The deviation happens in the second...
Thank you, @wlike, for looking into this. I wanted to post the code, but was quite busy last week and could not catch up. I had set up a nearly identical test case. For me, the mean was already off. The exact value also depends on where the normalization factor is multiplied in during the mean calculation (in the first gemv or the second), but the result stays approximately the same either way.
I pushed a new version of BN: https://github.com/borisgin/caffe/tree/caffe-0.15-bn
I just tracked down an issue with unexpected BatchNorm behavior in the current Caffe version. The effect can be condensed to the following scenario: for an input batch containing a channel with constant values, the layer produces outputs that differ noticeably from zero (how much depends on the magnitude of the values). The expected result is zero for every value in this channel, since subtracting the mean removes the constant and the variance is zero.
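For reference, the expected per-channel behavior can be written down directly. This is a minimal Python sketch, assuming the forward normalization y = (x - mean) / sqrt(var + eps) with no learned scale or shift and eps = 1e-5 (assumed to match Caffe's default); `batch_norm_channel` is a hypothetical name, not Caffe code. `math.fsum` gives a correctly rounded sum, so the mean of a constant channel comes out exact and the output is exactly zero.

```python
import math

def batch_norm_channel(values, eps=1e-5):
    """Normalize one channel: y = (x - mean) / sqrt(var + eps).

    Hypothetical reference helper; eps=1e-5 is an assumed default.
    The variance is computed as E((X - E(X))^2).
    """
    n = len(values)
    mean = math.fsum(values) / n                          # correctly rounded sum
    var = math.fsum((x - mean) ** 2 for x in values) / n  # population variance
    denom = math.sqrt(var + eps)
    return [(x - mean) / denom for x in values]

channel = [100.0] * 9               # 1x1x3x3 blob, all entries 100
print(batch_norm_channel(channel))  # nine 0.0 values
```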
To reproduce the effect, I created a test with a 1x1x3x3 input blob in which all entries are 100 (smaller values also trigger it). In this case, the computed mean is 100.000015 and the normalized output is about -0.00483, which is far from 0. The normalization factor amplifies the numerical deviation: because the variance is essentially zero, the denominator sqrt(var + eps) is close to sqrt(eps) ≈ 0.00316 (assuming the default eps of 1e-5), so the error of about 1.5e-5 in the mean is magnified roughly 316-fold. This highlights how badly conditioned the operation is numerically and how important it is to get the mean right (the variance is then computed as E((X - E(X))^2)). One possible improvement is the numerically stable online mean and variance computation described by Knuth, applied to sub-regions whose partial results are then combined with a reduction.
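The amplification can be checked with a few lines of arithmetic. This sketch assumes that the reported mean is the float32 value two ulps above 100 (that is, 100 + 2^-16, which prints as 100.000015) and that eps = 1e-5. It evaluates the normalization in double precision, so it only illustrates how the mean error propagates; it is not Caffe's kernel. `next_float32_up` is a hypothetical helper.

```python
import math
import struct

def next_float32_up(x):
    """Return the next representable float32 above a positive float32 value x."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits + 1))[0]

x = 100.0    # every entry of the 1x1x3x3 test blob
eps = 1e-5   # assumed: Caffe's default BatchNorm eps

# Assumed: the reported mean 100.000015 is float32(100) plus two ulps.
mean = next_float32_up(next_float32_up(x))
err = x - mean          # identical for every element of a constant channel
var = err * err         # E((X - E(X))^2) when every deviation equals err
gain = 1.0 / math.sqrt(var + eps)
y = err * gain

print(f"mean = {mean:.6f}")   # mean = 100.000015
print(f"y    = {y:.5f}")      # y    = -0.00483
print(f"gain = {gain:.1f}")   # gain = 316.2
```

Under these assumptions, the output matches the reported -0.00483: a mean error of only two float32 ulps, about 1.5e-5, is multiplied by 1/sqrt(eps).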
Interestingly, Chainer and mxnet do not seem to suffer from this effect. Chainer uses a reduction kernel for the mean computation, and mxnet takes a similar approach. Newer cuDNN versions probably handle this in a similar way.
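The stable online computation suggested above (Welford's method, as presented by Knuth), applied per region and followed by a reduction, might look like the Python sketch below. It runs on the CPU and does not reflect how Chainer, mxnet, or cuDNN actually implement this. The merge step is swapped in from the pairwise combination formula of Chan et al., and the names `welford`, `merge`, `channel_stats`, and `region_size` are hypothetical. For a constant channel, every delta after the first value is exactly zero, so both the per-region updates and the merges leave the mean exactly equal to the input and the variance exactly zero.

```python
from functools import reduce

def welford(values):
    """Online (Welford/Knuth) update over one region. Returns (count, mean, m2)."""
    count, mean, m2 = 0, 0.0, 0.0
    for x in values:
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)  # uses the updated mean
    return count, mean, m2

def merge(a, b):
    """Combine two partial results with Chan et al.'s pairwise formula."""
    na, mean_a, m2a = a
    nb, mean_b, m2b = b
    n = na + nb
    delta = mean_b - mean_a
    mean = mean_a + delta * nb / n
    m2 = m2a + m2b + delta * delta * na * nb / n
    return n, mean, m2

def channel_stats(values, region_size=4):
    """Split a channel into regions, run Welford per region, then reduce the partials."""
    if not values:
        raise ValueError("channel must contain at least one value")
    regions = [values[i:i + region_size] for i in range(0, len(values), region_size)]
    n, mean, m2 = reduce(merge, (welford(r) for r in regions))
    return mean, m2 / n  # population variance E((X - E(X))^2)

mean, var = channel_stats([100.0] * 9)
print(mean, var)  # 100.0 0.0
```

On a GPU, each region would roughly correspond to the work of one thread or block, and `merge` would be the combining operator of the reduction.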
This raises two issues:
For reference: