diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc index 6b3d7c818378..28da2712e373 100644 --- a/src/operator/nn/lrn.cc +++ b/src/operator/nn/lrn.cc @@ -204,6 +204,8 @@ NNVM_REGISTER_OP(_backward_LRN) .set_attr("TIsBackward", true) #if MXNET_USE_MKLDNN == 1 .set_attr("FComputeEx", LRNGradComputeExCPU) +// Native compute requires norm while MKLDNN does not so cannot be compared in debug mode +.set_attr("TExcludeMKLDNNDebug", true) #endif .set_attr("FCompute", LRNGradCompute); diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index 8ad42a46c1a0..f07e84cc1adc 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -496,6 +496,11 @@ void OpCheck::Run(mxnet::FCompute fn, const nnvm::NodeAttrs &attrs, const std::vector &inputs_, const std::vector &req, const std::vector &outputs_) { + + static auto& is_excluded = Op::GetAttr("TExcludeMKLDNNDebug"); + + if (is_excluded.get(attrs.op, false)) return; + std::vector in_blobs(inputs.size()); for (size_t i = 0; i < in_blobs.size(); i++) in_blobs[i] = inputs[i].data(); std::vector out_blobs(outputs.size());