Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add exclude debug flag
Browse files Browse the repository at this point in the history
  • Loading branch information
azai91 committed Aug 9, 2018
1 parent db6f1cd commit d0a68c6
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/operator/nn/lrn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ NNVM_REGISTER_OP(_backward_LRN)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
#if MXNET_USE_MKLDNN == 1
.set_attr<FComputeEx>("FComputeEx<cpu>", LRNGradComputeExCPU)
// Native compute requires norm while MKLDNN does not so cannot be compared in debug mode
.set_attr<bool>("TExcludeMKLDNNDebug", true)
#endif
.set_attr<FCompute>("FCompute<cpu>", LRNGradCompute<cpu>);

Expand Down
5 changes: 5 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,11 @@ void OpCheck::Run(mxnet::FCompute fn, const nnvm::NodeAttrs &attrs,
const std::vector<mxnet::NDArray> &inputs_,
const std::vector<mxnet::OpReqType> &req,
const std::vector<mxnet::NDArray> &outputs_) {

static auto& is_excluded = Op::GetAttr<bool>("TExcludeMKLDNNDebug");

if (is_excluded.get(attrs.op, false)) return;

std::vector<mxnet::TBlob> in_blobs(inputs.size());
for (size_t i = 0; i < in_blobs.size(); i++) in_blobs[i] = inputs[i].data();
std::vector<mxnet::TBlob> out_blobs(outputs.size());
Expand Down

0 comments on commit d0a68c6

Please sign in to comment.