
BatchNorm can not converge with scale=False #18475

Closed
nttstar opened this issue Jun 3, 2020 · 7 comments

nttstar commented Jun 3, 2020

Description

The BatchNorm operator with scale=False cannot converge.

Error Message

No error message, but the loss value and training accuracy are abnormal compared with scale=True BatchNorm.

To Reproduce

Train ArcFace with https://github.com/nttstar/arcface.np and add one BatchNorm op with scale=False after the final embedding layer.
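
A minimal sketch of this setup, assuming the Gluon API; the embedding size and input shape below are illustrative and not taken from arcface.np:

```python
# Minimal sketch of the reported setup, assuming the Gluon API;
# the embedding size and input shape are illustrative, not from arcface.np.
import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()
net.add(nn.Dense(512))              # final embedding layer (hypothetical size)
net.add(nn.BatchNorm(scale=False))  # gamma fixed to 1; only beta is learned
net.initialize()

x = mx.nd.random.uniform(shape=(8, 1024))
emb = net(x)                        # embeddings normalized without a learned scale
```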

What have you tried to solve it?

  1. Setting scale=True works, but with slightly worse test accuracy.

Environment

----------Python Info----------
Version : 3.6.9
Compiler : GCC 7.3.0
Build : ('default', 'Jul 30 2019 19:07:31')
Arch : ('64bit', '')
------------Pip Info-----------
Version : 19.3.1
Directory : /root/anaconda2/envs/py36/lib/python3.6/site-packages/pip
----------MXNet Info-----------
Version : 2.0.0
Directory : /root/anaconda2/envs/py36/lib/python3.6/site-packages/mxnet
Num GPUs : 8
Hashtag not found. Not installed from pre-built package.
----------System Info----------
Platform : Linux-3.10.0-327.el7.x86_64-x86_64-with-centos-7.5.1804-Core
system : Linux
node : gpu06
release : 3.10.0-327.el7.x86_64
version : #1 SMP Thu Nov 19 22:10:57 UTC 2015

nttstar added the Bug label Jun 3, 2020
wkcn (Member) commented Jun 4, 2020

Hi @nttstar, there was a bug in BatchNorm in a previous version of MXNet. The bug was reported in #18373 and fixed in PR #18377. Could you please try the latest version of MXNet?

sxjscience (Member) commented

@wkcn I communicated with @nttstar offline and #18373 should not be the cause. Could you help test with scale=False and see if anything is wrong?

wkcn (Member) commented Jun 5, 2020

@sxjscience I'm sorry that I don't have any machine with GPU to check it recently.

I read the code of BatchNorm and its unit test.
There is a gradient check when fix_gamma=True:

https://github.com/apache/incubator-mxnet/blob/master/tests/python/unittest/test_operator.py#L1777,

but no output check when fix_gamma=True:

https://github.com/apache/incubator-mxnet/blob/master/tests/python/unittest/test_operator.py#L1882
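
For illustration, a hedged sketch of the kind of forward-output check that appears to be missing, comparing the legacy BatchNorm op with fix_gamma=True against a NumPy reference; the helper name, shapes, and tolerances are illustrative and not taken from the test suite:

```python
# Sketch of an output check for fix_gamma=True (gamma treated as 1), comparing
# the legacy BatchNorm op against a NumPy reference. Helper name, shapes, and
# tolerances are illustrative; this is not the actual unittest code.
import numpy as np
import mxnet as mx

def check_bn_output(shape=(24, 3, 4), axis=1, eps=1e-5, ctx=mx.cpu()):
    nch = shape[axis]
    x = mx.nd.random.uniform(shape=shape, ctx=ctx)
    gamma = mx.nd.ones(nch, ctx=ctx)   # ignored (forced to 1) when fix_gamma=True
    beta = mx.nd.random.uniform(shape=(nch,), ctx=ctx)
    mov_mean = mx.nd.zeros(nch, ctx=ctx)
    mov_var = mx.nd.ones(nch, ctx=ctx)
    with mx.autograd.record():  # training mode: normalize with batch statistics
        y = mx.nd.BatchNorm(x, gamma, beta, mov_mean, mov_var,
                            eps=eps, fix_gamma=True, axis=axis)
    # NumPy reference: per-channel normalization with gamma == 1.
    xn = x.asnumpy()
    red = tuple(i for i in range(len(shape)) if i != axis)
    mean = xn.mean(axis=red, keepdims=True)
    var = xn.var(axis=red, keepdims=True)
    bshape = [1] * len(shape)
    bshape[axis] = nch
    ref = (xn - mean) / np.sqrt(var + eps) + beta.asnumpy().reshape(bshape)
    np.testing.assert_allclose(y.asnumpy(), ref, rtol=1e-3, atol=1e-3)

check_bn_output()
```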

wkcn (Member) commented Jun 5, 2020

I tested BatchNorm with fix_gamma=True.
The result on CPU is correct, but the result on GPU is wrong when fix_gamma=True, axis=1, and cudnn_off=False. bn_beta.grad is the only incorrect value.

Here are the failure cases when fix_gamma=True.

| operator  | shape            | axis | cudnn_off | output_mean_var |
|-----------|------------------|------|-----------|-----------------|
| BatchNorm | (24, 2)          | 1    | False     | False           |
| BatchNorm | (24, 2)          | 1    | False     | True            |
| BatchNorm | (24, 3, 4)       | 1    | False     | False           |
| BatchNorm | (24, 3, 4)       | 1    | False     | True            |
| BatchNorm | (24, 4, 4, 4)    | 1    | False     | False           |
| BatchNorm | (24, 4, 4, 4)    | 1    | False     | True            |
| BatchNorm | (24, 8, 4, 4)    | 1    | False     | False           |
| BatchNorm | (24, 8, 4, 4)    | 1    | False     | True            |
| BatchNorm | (24, 5, 6, 4, 4) | 1    | False     | False           |
| BatchNorm | (24, 5, 6, 4, 4) | 1    | False     | True            |
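
A hedged sketch of the kind of comparison described above: backprop through the legacy BatchNorm op with fix_gamma=True on CPU and GPU with identical inputs and compare beta's gradient (the helper name and shapes are illustrative, not the actual test code):

```python
# Sketch of the CPU-vs-GPU comparison: run the legacy BatchNorm op with
# fix_gamma=True on both devices with identical inputs and compare beta's
# gradient. Helper name and shapes are illustrative; not the actual test code.
import numpy as np
import mxnet as mx

def beta_grad(x_np, og_np, ctx, axis=1, cudnn_off=False):
    nch = x_np.shape[axis]
    x = mx.nd.array(x_np, ctx=ctx)
    gamma = mx.nd.ones(nch, ctx=ctx)
    beta = mx.nd.zeros(nch, ctx=ctx)
    mov_mean = mx.nd.zeros(nch, ctx=ctx)
    mov_var = mx.nd.ones(nch, ctx=ctx)
    beta.attach_grad()
    with mx.autograd.record():
        y = mx.nd.BatchNorm(x, gamma, beta, mov_mean, mov_var,
                            fix_gamma=True, axis=axis, cudnn_off=cudnn_off)
    y.backward(mx.nd.array(og_np, ctx=ctx))  # use the same head gradient on both devices
    return beta.grad.asnumpy()

shape = (24, 3, 4)
x_np = np.random.uniform(size=shape).astype('float32')
og_np = np.random.uniform(size=shape).astype('float32')
cpu = beta_grad(x_np, og_np, mx.cpu())
gpu = beta_grad(x_np, og_np, mx.gpu(0))  # before the fix, this disagreed with CPU
np.testing.assert_allclose(cpu, gpu, rtol=1e-3, atol=1e-3)
```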

wkcn (Member) commented Jun 5, 2020

I'm fixing the bug and I will submit a PR later.

wkcn (Member) commented Jun 10, 2020

Hi @nttstar, the BatchNorm bug has been fixed in #18500.

Thank you for the report!

nttstar (Author) commented Jun 10, 2020

@wkcn Thanks! I will check it in the next pip package release.
