This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.7.x] backport Invoke mkldnn and cudnn BatchNorm when axis != 1 to v1.7.x (#18676) (#18890)

* [Improvement] Invoke mkldnn and cudnn BatchNorm when axis != 1 (#18504)

* fix batch norm when fix_gamma is True

* support gradient accumulation for batch norm

* mkldnn batchnorm support grad add

* unittest for bn

* fix bn arg

* fix lint

* fix mkldnn

* fix mkldnn bn

* fix grad when fixing gamma

* fix naive gpu bn

* fix lint

* invoke mkldnn and cudnn batchnorm when axis != 1

* backport 18500

* change condition

* fix

* fix

* add mkldnn_off for bn

* remove mkldnn_off

* recover save_000800.json

* cast

* remove  and fix flaky test

Co-authored-by: JackieWu <wkcn@live.cn>

Co-authored-by: JackieWu <wkcn@live.cn>
stu1130 and wkcn authored Aug 14, 2020
1 parent d2d6408 commit d32ba4f
Showing 7 changed files with 70 additions and 52 deletions.
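In practical terms, the backport means that a `BatchNorm` whose `axis` is not the default channel axis (for example channels-last data) can now be served by the MKLDNN and cuDNN kernels instead of always falling back to the reference implementation. A minimal sketch of that case, assuming an MXNet 1.7 build with MKLDNN and/or CUDA enabled (shapes and parameter values here are arbitrary):

```python
import mxnet as mx

# Channels-last input: normalization runs over the trailing axis (C = 32).
x = mx.nd.random.uniform(shape=(2, 8, 8, 32))
gamma = mx.nd.ones((32,))
beta = mx.nd.zeros((32,))
moving_mean = mx.nd.zeros((32,))
moving_var = mx.nd.ones((32,))

# Before this change, axis != 1 forced the generic kernel; with the backport
# the same call may dispatch to MKLDNN on CPU or cuDNN on GPU.
y = mx.nd.BatchNorm(x, gamma, beta, moving_mean, moving_var,
                    axis=-1, fix_gamma=False)
print(y.shape)  # (2, 8, 8, 32)
```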
src/operator/nn/batch_norm.cc (12 changes: 8 additions & 4 deletions)
@@ -422,10 +422,14 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
 
 #if MXNET_USE_MKLDNN == 1
 static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
-  mxnet::TShape shape = input.shape();
-  return SupportMKLDNN(input) && shape.ndim() == 4
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
-      && !mxnet::op::batchnorm::disable_mkl;
+  if (mxnet::op::batchnorm::disable_mkl) return false;
+  const mxnet::TShape shape = input.shape();
+  const int ndim = shape.ndim();
+  if (ndim == 0 || shape.Size() == 0) return false;
+  const int dtype = input.dtype();
+  return (dtype == mshadow::kFloat32 ||
+          dtype == mshadow::kBfloat16) &&
+         SupportStorageMKLDNN(input.storage_type());
 }
 
 void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
src/operator/nn/batch_norm.cu (6 changes: 2 additions & 4 deletions)
@@ -704,8 +704,7 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
 
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
+  if (!param.use_global_stats && !param.cudnn_off) {
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
       GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states);
     })
@@ -733,8 +732,7 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
 
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
+  if (!param.use_global_stats && !param.cudnn_off) {
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
       GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs);
     })
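With the `param.axis == DEFAULT_AXIS` guard dropped above, the cuDNN forward and backward kernels now also serve non-default axes. One hedged way to sanity-check the new path, assuming a CUDA/cuDNN build with a visible GPU (tolerances are illustrative; `cudnn_off=True` forces the generic kernel for comparison):

```python
import mxnet as mx

ctx = mx.gpu(0)
x = mx.nd.random.uniform(shape=(2, 5, 6, 4), ctx=ctx)
aux = [mx.nd.ones((4,), ctx=ctx),    # gamma
       mx.nd.zeros((4,), ctx=ctx),   # beta
       mx.nd.zeros((4,), ctx=ctx),   # moving_mean
       mx.nd.ones((4,), ctx=ctx)]    # moving_var

# Same call with and without cuDNN; axis=-1 exercises the newly enabled path.
y_cudnn = mx.nd.BatchNorm(x, *aux, axis=-1, fix_gamma=False, cudnn_off=False)
y_plain = mx.nd.BatchNorm(x, *aux, axis=-1, fix_gamma=False, cudnn_off=True)

mx.test_utils.assert_almost_equal(y_cudnn.asnumpy(), y_plain.asnumpy(),
                                  rtol=1e-4, atol=1e-4)
```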
src/operator/nn/cudnn/cudnn_batch_norm-inl.h (26 changes: 19 additions & 7 deletions)
@@ -260,15 +260,27 @@ class CuDNNBatchNormOp {
 
  private:
   void Init(const TBlob &in_data) {
-    if (in_data.ndim() == 4) {
-      for (int i = 0; i < 4; ++i)
-        shape_[i] = in_data.shape_[i];
+    CHECK_GE(param_.axis, 0);
+    CHECK_LT(param_.axis, in_data.ndim());
+    if (param_.axis == 1) {
+      if (in_data.ndim() == 4) {
+        for (int i = 0; i < 4; ++i)
+          shape_[i] = in_data.shape_[i];
+      } else {
+        // when in_data.ndim() != 4
+        shape_[0] = in_data.shape_[0];
+        shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+        shape_[2] = 1;
+        shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(2,
+                    in_data.ndim()));
+      }
     } else {
-      // when in_data.ndim() != 4
-      shape_[0] = in_data.shape_[0];
-      shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+      // reshape to (N, C, 1, D), C is the `param_.axis` dimension
+      shape_[0] = static_cast<dim_t>(in_data.shape_.ProdShape(0, param_.axis));
+      shape_[1] = in_data.shape_[param_.axis];
       shape_[2] = 1;
-      shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim());
+      shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(param_.axis + 1,
+                  in_data.ndim()));
     }
 
     CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_,
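The descriptor setup above, and the MKLDNN changes in the next file, rely on the same trick: collapse an arbitrary input shape to a 4-D (N, C, 1, D) view with the normalized axis kept as C. A small stand-alone sketch of that shape computation, mirroring the `ProdShape(0, axis)` / `ProdShape(axis + 1, ndim)` calls in the patch (`collapse_to_nc1d` is a hypothetical helper for illustration, not MXNet API):

```python
from functools import reduce
from operator import mul

def collapse_to_nc1d(shape, axis):
    """Collapse `shape` to (N, C, 1, D), keeping dimension `axis` as C."""
    axis = axis if axis >= 0 else axis + len(shape)
    n = reduce(mul, shape[:axis], 1)      # product of dims before the axis
    d = reduce(mul, shape[axis + 1:], 1)  # product of dims after the axis
    return (n, shape[axis], 1, d)

assert collapse_to_nc1d((4, 6, 4, 5), axis=3) == (96, 5, 1, 1)
assert collapse_to_nc1d((4, 3, 4), axis=-1) == (12, 4, 1, 1)
assert collapse_to_nc1d((4, 2), axis=1) == (4, 2, 1, 1)
```

Because batch normalization only distinguishes the channel axis from "everything else", normalizing the reshaped (N, C, 1, D) tensor over axis 1 is equivalent to normalizing the original tensor over `axis`, which is what lets the fixed-layout cuDNN and MKLDNN kernels be reused.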
src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h (44 changes: 39 additions & 5 deletions)
@@ -157,7 +157,25 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                             const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
                             const std::vector<NDArray> &outputs, bool fuse_relu) {
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-  const std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
+  std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
+
+  mxnet::TShape shape = inputs[batchnorm::kData].shape();
+  const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
+  CHECK_LT(real_axis, shape.ndim());
+  NDArray out = outputs[batchnorm::kOut];
+  if (param.axis != 1 || shape.ndim() != 4) {
+    // reshape to (N, C, 1, D)
+    mxnet::TShape new_shape{
+      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      shape[real_axis],
+      1,
+      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+                         static_cast<int>(shape.ndim())))
+    };
+    in_data[batchnorm::kData] = in_data[batchnorm::kData].Reshape(new_shape);
+    out = out.Reshape(new_shape);
+  }
+
   const std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
   mkldnn::normalization_flags flags = _GetFlags(in_data,
@@ -166,7 +184,6 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                                                 fuse_relu);
   const NDArray &data = in_data[batchnorm::kData];
   auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
-  const NDArray &out = outputs[batchnorm::kOut];
 
   // for output memory
   auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
@@ -325,9 +342,9 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                                                 ctx.is_train && !param.use_global_stats,
                                                 fuse_relu);
 
-  const NDArray &data = in_data[batchnorm::kData];
-  const NDArray &diff = out_grad[batchnorm::kOut];
-  const NDArray &gradIn = in_grad[batchnorm::kData];
+  NDArray data = in_data[batchnorm::kData];
+  NDArray diff = out_grad[batchnorm::kOut];
+  NDArray gradIn = in_grad[batchnorm::kData];
   const NDArray &moving_mean = aux_states[batchnorm::kMovingMean];
   const NDArray &moving_var = aux_states[batchnorm::kMovingVar];
   const NDArray &out_mean = out_data[batchnorm::kMean];
@@ -338,6 +355,23 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   CHECK(moving_mean.IsDefaultData());
   CHECK(moving_var.IsDefaultData());
 
+  mxnet::TShape shape = data.shape();
+  const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
+  CHECK_LT(real_axis, shape.ndim());
+  if (param.axis != 1 || shape.ndim() != 4) {
+    // reshape to (N, C, 1, D)
+    mxnet::TShape new_shape{
+      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      shape[real_axis],
+      1,
+      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+                         static_cast<int>(shape.ndim())))
+    };
+    data = data.Reshape(new_shape);
+    diff = diff.Reshape(new_shape);
+    gradIn = gradIn.Reshape(new_shape);
+  }
+
   auto data_mem = data.GetMKLDNNData();
   auto diff_mem = diff.GetMKLDNNData();
   // MKLDNN batchnorm should run on special layouts. If one of them isn't, we
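The backward changes above (non-const `data`/`diff`/`gradIn` plus the same (N, C, 1, D) reshape) also connect to the commit-message items about gradient accumulation ("support gradient accumulation for batch norm", "mkldnn batchnorm support grad add"). A rough Gluon sketch of the usage pattern this enables, with illustrative shapes:

```python
import mxnet as mx
from mxnet import autograd, gluon

net = gluon.nn.BatchNorm(axis=-1, in_channels=3)  # channels-last BatchNorm
net.initialize()
for p in net.collect_params().values():
    if p.grad_req != 'null':       # skip running_mean / running_var
        p.grad_req = 'add'         # accumulate gradients instead of overwriting

for _ in range(2):                 # two micro-batches before an optimizer step
    x = mx.nd.random.uniform(shape=(4, 8, 8, 3))
    with autograd.record():
        y = net(x)
    y.backward()                   # gamma/beta gradients are summed across calls

# With grad_req='add', gradients must be zeroed manually after each update.
```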
tests/python/unittest/test_numpy_op.py (2 changes: 1 addition & 1 deletion)
@@ -1541,7 +1541,7 @@ def _test_batchnorm_impl(shape, fix_gamma, cudnn_off, output_mean_var,
         assert_almost_equal(
             bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol)
 
-    shapes = [(24, 2), (24, 3, 4), (24, 8, 4, 5), (24, 5, 6, 4, 5)]
+    shapes = [(4, 2), (4, 3, 4), (4, 6, 4, 5), (4, 5, 6, 4, 5)]
     bools = [False, True]
     for shape, fix_gamma, cudnn_off, output_mean_var in itertools.product(
             shapes, bools, bools, bools):
tests/python/unittest/test_operator.py (2 changes: 1 addition & 1 deletion)
@@ -1964,7 +1964,7 @@ def _test_batchnorm_impl(op_name, shape, fix_gamma, cudnn_off, output_mean_var,
             bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol)
 
     op_names = ['BatchNorm', 'SyncBatchNorm']
-    shapes = [(24, 2), (24, 3, 4), (24, 8, 4, 5), (24, 5, 6, 4, 5)]
+    shapes = [(4, 2), (4, 3, 4), (4, 6, 4, 5), (4, 5, 6, 4, 5)]
     bools = [False, True]
     for op_name, shape, fix_gamma, cudnn_off, output_mean_var in itertools.product(
             op_names, shapes, bools, bools, bools):
tests/python/unittest/test_symbol.py (30 changes: 0 additions & 30 deletions)
@@ -272,36 +272,6 @@ def check_symbol_consistency(sym1, sym2, ctx, skip_grad=False, equal_nan=False):
                              grad_req='null' if skip_grad else 'write',
                              equal_nan=equal_nan)
 
-def test_load_000800():
-    with mx.AttrScope(ctx_group='stage1'):
-        data = mx.symbol.Variable('data', lr_mult=0.2)
-        weight = mx.sym.Variable(name='fc1_weight', lr_mult=1.2)
-        fc1 = mx.symbol.FullyConnected(data = data, weight=weight, name='fc1', num_hidden=128, wd_mult=0.3)
-        act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
-
-    set_stage1 = set(act1.list_arguments())
-    with mx.AttrScope(ctx_group='stage2'):
-        fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64, lr_mult=0.01)
-        act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
-        fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
-        fc3 = mx.symbol.BatchNorm(fc3, name='batchnorm0')
-        sym1 = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')
-
-    curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-    sym2 = mx.sym.load(os.path.join(curr_path, 'save_000800.json'))
-
-    attr1 = sym1.attr_dict()
-    attr2 = sym2.attr_dict()
-    for k, v1 in attr1.items():
-        assert k in attr2, k
-        v2 = attr2[k]
-        for kk, vv1 in v1.items():
-            if kk.startswith('__') and kk.endswith('__'):
-                assert kk in v2 and v2[kk] == vv1, k + str(v1) + str(v2)
-
-    check_symbol_consistency(sym1, sym2,
-        {'ctx': mx.cpu(0), 'group2ctx': {'stage1' : mx.cpu(1), 'stage2' : mx.cpu(2)}, 'data': (1,200)})
-
 
 def test_blockgrad():
     a = mx.sym.Variable('a')
