temporary fix for batch norm storage fallback (apache#156)
eric-haibin-lin committed Aug 10, 2017
1 parent 2d93d72 commit 80a590d
Showing 2 changed files with 65 additions and 51 deletions.
2 changes: 1 addition & 1 deletion src/operator/batch_norm.cc
@@ -230,7 +230,7 @@ void BatchNormOp<xpu, DType, AccReal>::DoBackward(mshadow::Stream<cpu> *,
   #pragma omp parallel for
   for (int channel = 0; channel < static_cast<int>(channelCount); ++channel) {
     const AccReal *weight = weights.dptr<AccReal>();
-    const AccReal w = weight ? weight[channel] : AccReal(1);
+    const AccReal w = !param_.fix_gamma ? weight[channel] : AccReal(1);
     AccReal mean, invstd;
     if (is_train_and_not_global_stats) {
       mean = saveMeanDataPtr[channel];
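Note on the one-line change above: the old guard inferred "gamma is fixed" from the weight pointer being null, while the new one tests param_.fix_gamma directly; with the sparse storage fallback in play the weight blob presumably always carries a valid pointer, so nullness no longer encodes the right condition. A minimal sketch of the invariant the new guard preserves, assuming the imperative mx.nd.BatchNorm operator (illustrative values, not part of this commit): with fix_gamma=True the stored gamma values should be ignored and treated as 1.

    import mxnet as mx
    import numpy as np

    # Hypothetical sanity check: with fix_gamma=True the operator should behave
    # as if gamma were all ones, whatever the gamma NDArray actually holds.
    x = mx.nd.array(np.random.normal(size=(2, 3)))
    gamma = mx.nd.array([5.0, 5.0, 5.0])              # deliberately not 1
    beta = mx.nd.zeros((3,))
    mean, var = mx.nd.zeros((3,)), mx.nd.ones((3,))   # neutral moving stats

    fixed = mx.nd.BatchNorm(x, gamma, beta, mean, var, fix_gamma=True)
    unit = mx.nd.BatchNorm(x, mx.nd.ones((3,)), beta, mean, var, fix_gamma=False)
    print(np.allclose(fixed.asnumpy(), unit.asnumpy(), atol=1e-5))  # expect True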
114 changes: 64 additions & 50 deletions tests/python/unittest/test_operator.py
@@ -855,75 +855,89 @@ def test_nearest_upsampling():
                     check_nearest_upsampling_with_shape(shapes, scale, root_scale)
 
 def test_batchnorm_training():
-    for shape in [(2, 3), (2, 3, 2, 2)]:
-        data_tmp = np.random.normal(-0.1, 0.1, size=shape)
-        s = shape[1],
-        gamma = np.ones(s)
-        beta = np.ones(s)
-        gamma[1] = 3
-        beta[0] = 3
+    def check_batchnorm_training(stype):
+        for shape in [(2, 3), (2, 3, 2, 2)]:
+            data_tmp = np.random.normal(-0.1, 0.1, size=shape)
+            s = shape[1],
+            gamma = np.ones(s)
+            beta = np.ones(s)
+            gamma[1] = 3
+            beta[0] = 3
 
-        rolling_mean = np.random.uniform(size=s)
-        rolling_std = np.random.uniform(size=s)
+            rolling_mean = np.random.uniform(size=s)
+            rolling_std = np.random.uniform(size=s)
 
-        data = mx.symbol.Variable('data')
-        stype = 'row_sparse'
+            data = mx.symbol.Variable('data', stype=stype)
+            in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(gamma).tostype(stype),
+                           mx.nd.array(beta).tostype(stype)]
+            mean_std = [mx.nd.array(rolling_mean).tostype(stype), mx.nd.array(rolling_std).tostype(stype)]
 
-        test = mx.symbol.BatchNorm_v1(data, fix_gamma=True)
-        check_numeric_gradient(test, [data_tmp, gamma, beta], [rolling_mean, rolling_std], numeric_eps=1e-2, rtol=0.16)
+            test = mx.symbol.BatchNorm_v1(data, fix_gamma=True)
+            check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16)
 
-        test = mx.symbol.BatchNorm(data, fix_gamma=True)
-        check_numeric_gradient(test, [data_tmp, gamma, beta], [rolling_mean, rolling_std], numeric_eps=1e-2, rtol=0.16)
+            test = mx.symbol.BatchNorm(data, fix_gamma=True)
+            check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16)
 
-        test = mx.symbol.BatchNorm_v1(data, fix_gamma=True, use_global_stats=True)
-        check_numeric_gradient(test, [data_tmp, gamma, beta], [rolling_mean, rolling_std], numeric_eps=1e-2, rtol=0.16)
+            test = mx.symbol.BatchNorm_v1(data, fix_gamma=True, use_global_stats=True)
+            check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16)
 
-        test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True)
-        check_numeric_gradient(test, [data_tmp, gamma, beta], [rolling_mean, rolling_std], numeric_eps=1e-2, rtol=0.16)
+            test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True)
+            check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16)
 
-        test = mx.symbol.BatchNorm_v1(data, fix_gamma=False)
-        check_numeric_gradient(test, [data_tmp, gamma, beta], [rolling_mean, rolling_std], numeric_eps=1e-2, rtol=0.16)
+            test = mx.symbol.BatchNorm_v1(data, fix_gamma=False)
+            check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16)
 
-        test = mx.symbol.BatchNorm(data, fix_gamma=False)
-        check_numeric_gradient(test, [data_tmp, gamma, beta], [rolling_mean, rolling_std], numeric_eps=1e-2, rtol=0.16)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False)
+            check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16)
 
-        test = mx.symbol.BatchNorm_v1(data, fix_gamma=False, use_global_stats=True)
-        check_numeric_gradient(test, [data_tmp, gamma, beta], [rolling_mean, rolling_std], numeric_eps=1e-2, rtol=0.16)
+            test = mx.symbol.BatchNorm_v1(data, fix_gamma=False, use_global_stats=True)
+            check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16)
 
-        test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True)
-        check_numeric_gradient(test, [data_tmp, gamma, beta], [rolling_mean, rolling_std], numeric_eps=1e-2, rtol=0.16)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True)
+            check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16)
 
-        # Test varying channel axis
-        dim = len(shape)
-        for chaxis in range(-dim, dim):
-            chaxis_true = chaxis
-            if chaxis < 0:
-                chaxis_true = dim + chaxis
+            # Test varying channel axis
+            dim = len(shape)
+            for chaxis in range(-dim, dim):
+                chaxis_true = chaxis
+                if chaxis < 0:
+                    chaxis_true = dim + chaxis
 
-            shapex = shape
+                shapex = shape
 
-            channel_count = shapex[chaxis_true]
-            data_tmp = np.random.normal(-0.1, 0.1, size=shapex)
+                channel_count = shapex[chaxis_true]
+                data_tmp = np.random.normal(-0.1, 0.1, size=shapex)
 
-            gamma = np.ones(channel_count)
-            beta = np.ones(channel_count)
-            if channel_count > 1:
-                gamma[1] = 3
-            beta[0] = 3
+                gamma = np.ones(channel_count)
+                beta = np.ones(channel_count)
+                if channel_count > 1:
+                    gamma[1] = 3
+                beta[0] = 3
 
-            xrolling_mean = np.random.uniform(size=channel_count)
-            xrolling_std = np.random.uniform(size=channel_count)
+                in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(gamma).tostype(stype),
+                               mx.nd.array(beta).tostype(stype)]
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=True, axis=chaxis)
-            check_numeric_gradient(test, [data_tmp, gamma, beta], [xrolling_mean, xrolling_std], numeric_eps=1e-2, rtol=0.2, atol=0.01)
+                xrolling_mean = np.random.uniform(size=channel_count)
+                xrolling_std = np.random.uniform(size=channel_count)
+                xmean_std = [mx.nd.array(xrolling_mean).tostype(stype),
+                             mx.nd.array(xrolling_std).tostype(stype)]
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True, axis=chaxis)
-            check_numeric_gradient(test, [data_tmp, gamma, beta], [xrolling_mean, xrolling_std], numeric_eps=1e-2, rtol=0.2, atol=0.01)
+                test = mx.symbol.BatchNorm(data, fix_gamma=True, axis=chaxis)
+                check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=False, axis=chaxis)
-            check_numeric_gradient(test, [data_tmp, gamma, beta], [xrolling_mean, xrolling_std], numeric_eps=1e-2, rtol=0.2, atol=0.01)
+                test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True, axis=chaxis)
+                check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
-            check_numeric_gradient(test, [data_tmp, gamma, beta], [xrolling_mean, xrolling_std], numeric_eps=1e-2, rtol=0.2, atol=0.01)
+                test = mx.symbol.BatchNorm(data, fix_gamma=False, axis=chaxis)
+                check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
+
+                test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
+                check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
+
+    stypes = ['row_sparse', 'csr', 'default']
+    for stype in stypes:
+        check_batchnorm_training(stype)
 
 def test_convolution_grouping():
     num_filter = 4
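With the test body moved into a helper parameterized over storage types, the same numeric-gradient checks now run once per stype, covering both the native dense path and sparse inputs that fall back to the dense batch norm implementation. A minimal sketch of the conversion step the test leans on, assuming NDArray.tostype round-trips values losslessly (names illustrative, not from the diff):

    import mxnet as mx
    import numpy as np

    # Each stype sees the same values: dense data is cast with tostype(stype)
    # and must survive the round trip back to dense storage unchanged.
    dense = mx.nd.array(np.random.normal(size=(2, 3)))
    for stype in ['row_sparse', 'csr', 'default']:
        converted = dense.tostype(stype)
        back = converted.tostype('default')
        assert np.allclose(back.asnumpy(), dense.asnumpy())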
