Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fixed issue with batchnorm on even number of channels (#20927)
Browse files Browse the repository at this point in the history
  • Loading branch information
piotrwolinski-intel authored Mar 7, 2022
1 parent edba375 commit ae7a104
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 14 deletions.
10 changes: 6 additions & 4 deletions src/operator/contrib/batch_norm_relu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,13 @@ void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs &attrs,
const std::vector<NDArray> &outputs) {
CHECK_EQ(inputs.size(), 5U);
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
bool fuse_relu = true;

if (SupportMKLDNNBNReLU(inputs[0], param)) {
CHECK_GT(outputs.size(), 3U);
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNN_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
MKLDNNBatchNormForward<DTYPE>(attrs, ctx, inputs, req, outputs, fuse_relu);
MKLDNNRun(MKLDNNBatchNormForward<DTYPE, /*fuse_relu*/ true>, attrs, ctx,
inputs, req, outputs);
});
return;
}
Expand All @@ -156,11 +157,12 @@ void BatchNormWithReLUGradComputeExCPU(const nnvm::NodeAttrs &attrs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
bool fuse_relu = true;

if (SupportMKLDNNBNReLU(inputs[0], param)) {
CHECK_EQ(inputs.size(), 9U);
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNBatchNormBackward<float>(attrs, ctx, inputs, req, outputs, fuse_relu);
MKLDNNRun(MKLDNNBatchNormBackward<float, /*fuse_relu*/ true>, attrs, ctx,
inputs, req, outputs);
return;
}
LOG(FATAL) << "BatchNormWithReLU operator only supports MKL-DNN Backend.";
Expand Down
10 changes: 6 additions & 4 deletions src/operator/nn/batch_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -452,11 +452,12 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
const std::vector<NDArray> &outputs) {
CHECK_EQ(inputs.size(), 5U);
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
bool fuse_relu = false;

if (SupportMKLDNNBN(inputs[0], param)) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNN_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
MKLDNNBatchNormForward<DTYPE>(attrs, ctx, inputs, req, outputs, fuse_relu);
MKLDNNRun(MKLDNNBatchNormForward<DTYPE, /*fuse_relu*/ false>, attrs, ctx,
inputs, req, outputs);
});
MKLDNN_OPCHECK_RUN(BatchNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
Expand All @@ -470,10 +471,11 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
bool fuse_relu = false;

if (SupportMKLDNNBN(inputs[0], param)) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNBatchNormBackward<float>(attrs, ctx, inputs, req, outputs, fuse_relu);
MKLDNNRun(MKLDNNBatchNormBackward<float, /*fuse_relu*/ false>, attrs, ctx,
inputs, req, outputs);
MKLDNN_OPCHECK_RUN(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
Expand Down
36 changes: 30 additions & 6 deletions src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,12 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
}

template <typename DType>
void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs, bool fuse_relu) {
void MKLDNNBatchNormForwardImpl(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs,
bool fuse_relu) {
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);

Expand Down Expand Up @@ -263,6 +266,15 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
}
}

/*!
 * \brief Five-argument entry point for the MKLDNN batch-norm forward pass.
 *
 * The ReLU-fusion choice is carried as the compile-time template parameter
 * `fuse_relu` rather than as a sixth function argument, and is handed on to
 * MKLDNNBatchNormForwardImpl as an ordinary runtime bool.
 *
 * \tparam DType     element type the implementation is instantiated with
 * \tparam fuse_relu value forwarded as the `fuse_relu` argument of the Impl
 * \param attrs      node attributes, forwarded unchanged
 * \param ctx        operator context, forwarded unchanged
 * \param inputs     input arrays, forwarded unchanged
 * \param req        write requests, forwarded unchanged
 * \param outputs    output arrays, forwarded unchanged
 */
template <typename DType, bool fuse_relu>
void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs,
                            const OpContext &ctx,
                            const std::vector<NDArray> &inputs,
                            const std::vector<OpReqType> &req,
                            const std::vector<NDArray> &outputs) {
  // Turn the template flag into the runtime argument the implementation takes.
  constexpr bool with_relu = fuse_relu;
  MKLDNNBatchNormForwardImpl<DType>(attrs, ctx, inputs, req, outputs, with_relu);
}

class MKLDNNBNBackward {
std::shared_ptr<mkldnn::batch_normalization_backward> bwd;
const std::shared_ptr<mkldnn::memory> weight_m;
Expand Down Expand Up @@ -310,9 +322,12 @@ static MKLDNNBNBackward &GetBNBackward(
}

template <typename DType>
void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs, bool fuse_relu) {
void MKLDNNBatchNormBackwardImpl(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs,
bool fuse_relu) {
if (fuse_relu) {
CHECK_EQ(inputs.size(), 9U);
} else {
Expand Down Expand Up @@ -477,6 +492,15 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ...";
}
}

/*!
 * \brief Five-argument entry point for the MKLDNN batch-norm backward pass.
 *
 * The ReLU-fusion choice is carried as the compile-time template parameter
 * `fuse_relu` rather than as a sixth function argument, and is handed on to
 * MKLDNNBatchNormBackwardImpl as an ordinary runtime bool.
 *
 * \tparam DType     element type the implementation is instantiated with
 * \tparam fuse_relu value forwarded as the `fuse_relu` argument of the Impl
 * \param attrs      node attributes, forwarded unchanged
 * \param ctx        operator context, forwarded unchanged
 * \param inputs     input arrays, forwarded unchanged
 * \param req        write requests, forwarded unchanged
 * \param outputs    output arrays, forwarded unchanged
 */
template <typename DType, bool fuse_relu>
void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs,
                             const OpContext &ctx,
                             const std::vector<NDArray> &inputs,
                             const std::vector<OpReqType> &req,
                             const std::vector<NDArray> &outputs) {
  // Turn the template flag into the runtime argument the implementation takes.
  constexpr bool with_relu = fuse_relu;
  MKLDNNBatchNormBackwardImpl<DType>(attrs, ctx, inputs, req, outputs, with_relu);
}
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN
Expand Down

0 comments on commit ae7a104

Please sign in to comment.