From ae7a104fc19d3586db932235d5eb8da3b123e8dc Mon Sep 17 00:00:00 2001
From: PiotrWolinski - Intel
Date: Mon, 7 Mar 2022 11:54:18 +0100
Subject: [PATCH] Fixed issue with batchnorm on even number of channels (#20927)

---
 src/operator/contrib/batch_norm_relu.cc       | 10 +++---
 src/operator/nn/batch_norm.cc                 | 10 +++---
 .../nn/mkldnn/mkldnn_batch_norm-inl.h         | 36 +++++++++++++++----
 3 files changed, 42 insertions(+), 14 deletions(-)

diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/contrib/batch_norm_relu.cc
index d1f409c975a5..e9a0e9581407 100644
--- a/src/operator/contrib/batch_norm_relu.cc
+++ b/src/operator/contrib/batch_norm_relu.cc
@@ -138,12 +138,13 @@ void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs &attrs,
                                    const std::vector<NDArray> &outputs) {
   CHECK_EQ(inputs.size(), 5U);
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-  bool fuse_relu = true;
+
   if (SupportMKLDNNBNReLU(inputs[0], param)) {
     CHECK_GT(outputs.size(), 3U);
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     MKLDNN_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
-      MKLDNNBatchNormForward<DTYPE>(attrs, ctx, inputs, req, outputs, fuse_relu);
+      MKLDNNRun(MKLDNNBatchNormForward<DTYPE, true>, attrs, ctx,
+                inputs, req, outputs);
     });
     return;
   }
@@ -156,11 +157,12 @@ void BatchNormWithReLUGradComputeExCPU(const nnvm::NodeAttrs &attrs,
                                        const std::vector<OpReqType> &req,
                                        const std::vector<NDArray> &outputs) {
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-  bool fuse_relu = true;
+
   if (SupportMKLDNNBNReLU(inputs[0], param)) {
     CHECK_EQ(inputs.size(), 9U);
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    MKLDNNBatchNormBackward<float>(attrs, ctx, inputs, req, outputs, fuse_relu);
+    MKLDNNRun(MKLDNNBatchNormBackward<float, true>, attrs, ctx,
+              inputs, req, outputs);
     return;
   }
   LOG(FATAL) << "BatchNormWithReLU operator only supports MKL-DNN Backend.";
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 7701099dc801..ef39f228fe2f 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -452,11 +452,12 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
                            const std::vector<NDArray> &outputs) {
   CHECK_EQ(inputs.size(), 5U);
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-  bool fuse_relu = false;
+
   if (SupportMKLDNNBN(inputs[0], param)) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     MKLDNN_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
-      MKLDNNBatchNormForward<DTYPE>(attrs, ctx, inputs, req, outputs, fuse_relu);
+      MKLDNNRun(MKLDNNBatchNormForward<DTYPE, false>, attrs, ctx,
+                inputs, req, outputs);
     });
     MKLDNN_OPCHECK_RUN(BatchNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
@@ -470,10 +471,11 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
                                const std::vector<OpReqType> &req,
                                const std::vector<NDArray> &outputs) {
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-  bool fuse_relu = false;
+
   if (SupportMKLDNNBN(inputs[0], param)) {
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    MKLDNNBatchNormBackward<float>(attrs, ctx, inputs, req, outputs, fuse_relu);
+    MKLDNNRun(MKLDNNBatchNormBackward<float, false>, attrs, ctx,
+              inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 75c7c4dbf38a..b443d3c92831 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -146,9 +146,12 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
 }
 
 template <typename DType>
-void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
-                            const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
-                            const std::vector<NDArray> &outputs, bool fuse_relu) {
+void MKLDNNBatchNormForwardImpl(const nnvm::NodeAttrs &attrs,
+                                const OpContext &ctx,
+                                const std::vector<NDArray> &inputs,
+                                const std::vector<OpReqType> &req,
+                                const std::vector<NDArray> &outputs,
+                                bool fuse_relu) {
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
   std::vector<NDArray> in_data(inputs.begin(),
                                inputs.begin() + batchnorm::kInMovingMean);
@@ -263,6 +266,15 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   }
 }
 
+template <typename DType, bool fuse_relu>
+void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+                            const std::vector<NDArray> &inputs,
+                            const std::vector<OpReqType> &req,
+                            const std::vector<NDArray> &outputs) {
+  MKLDNNBatchNormForwardImpl<DType>(attrs, ctx, inputs, req, outputs,
+                                    fuse_relu);
+}
+
 class MKLDNNBNBackward {
   std::shared_ptr<mkldnn::batch_normalization_backward> bwd;
   const std::shared_ptr<mkldnn::memory> weight_m;
@@ -310,9 +322,12 @@ static MKLDNNBNBackward &GetBNBackward(
 }
 
 template <typename DType>
-void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
-                             const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
-                             const std::vector<NDArray> &outputs, bool fuse_relu) {
+void MKLDNNBatchNormBackwardImpl(const nnvm::NodeAttrs &attrs,
+                                 const OpContext &ctx,
+                                 const std::vector<NDArray> &inputs,
+                                 const std::vector<OpReqType> &req,
+                                 const std::vector<NDArray> &outputs,
+                                 bool fuse_relu) {
   if (fuse_relu) {
     CHECK_EQ(inputs.size(), 9U);
   } else {
@@ -477,6 +492,15 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
     LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ...";
   }
 }
+
+template <typename DType, bool fuse_relu>
+void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+                             const std::vector<NDArray> &inputs,
+                             const std::vector<OpReqType> &req,
+                             const std::vector<NDArray> &outputs) {
+  MKLDNNBatchNormBackwardImpl<DType>(attrs, ctx, inputs, req, outputs,
+                                     fuse_relu);
+}
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_MKLDNN
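
Note on the pattern used above: MKLDNNRun only accepts functions with the
FComputeEx-style signature (attrs, ctx, inputs, req, outputs), so the patch
lifts the runtime flag fuse_relu into a template parameter, keeping the old
flag-taking code as an *Impl function and generating one fixed-signature
wrapper per flag value. Below is a minimal standalone C++ sketch of that
technique, assuming nothing beyond standard C++; Tensor, BatchNormImpl,
BatchNorm, and Run are illustrative stand-ins, not MXNet or MKL-DNN APIs.

    #include <cstddef>
    #include <iostream>
    #include <vector>

    using Tensor = std::vector<float>;

    // Shared implementation: the flag stays a plain runtime argument here,
    // mirroring MKLDNNBatchNormForwardImpl in the patch.
    void BatchNormImpl(const Tensor &in, Tensor *out, bool fuse_relu) {
      out->resize(in.size());
      for (std::size_t i = 0; i < in.size(); ++i) {
        float v = in[i];  // the normalization itself is omitted in this sketch
        (*out)[i] = (fuse_relu && v < 0.f) ? 0.f : v;
      }
    }

    // Wrapper: lifting the flag into a template parameter gives
    // BatchNorm<true> and BatchNorm<false> the identical fixed signature
    // that a generic runner requires.
    template <bool fuse_relu>
    void BatchNorm(const Tensor &in, Tensor *out) {
      BatchNormImpl(in, out, fuse_relu);
    }

    // Generic runner accepting only the fixed signature, in the spirit of
    // MKLDNNRun (which additionally reorders MKL-DNN memory, the step that
    // fixes the even-channel-count issue in the patch).
    template <typename Fn>
    void Run(Fn fn, const Tensor &in, Tensor *out) {
      fn(in, out);
    }

    int main() {
      Tensor in = {-1.f, 2.f}, out;
      Run(BatchNorm<true>, in, &out);   // fused-ReLU variant
      std::cout << out[0] << ' ' << out[1] << '\n';  // prints: 0 2
      Run(BatchNorm<false>, in, &out);  // plain variant
      std::cout << out[0] << ' ' << out[1] << '\n';  // prints: -1 2
      return 0;
    }

Because each flag value instantiates its own function, the compiler can also
discard the dead branch inside every instantiation; a lambda capturing a
runtime bool would match the signature too, but would keep the branch live.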