Skip to content

Commit

Permalink
move the position of MKL_Compute
Browse files Browse the repository at this point in the history
  • Loading branch information
juliusshufan committed Apr 18, 2019
1 parent f96c34a commit 06c51e9
Showing 1 changed file with 21 additions and 22 deletions.
43 changes: 21 additions & 22 deletions src/operator/tensor/elemwise_unary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,27 @@ class UnaryOp : public OpBase {
}

#if MSHADOW_USE_MKL == 1
template<typename OP, typename MKL_OP>
static void MKL_Compute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (req[0] == kNullOp) return;
auto type_flag = inputs[0].type_flag_;
size_t input_size = inputs[0].Size();
if ((req[0] == kWriteTo || req[0] == kWriteInplace) &&
mkl_func::check_size(input_size) &&
mkl_func::check_type(type_flag)) {
// set DType as float or double according to type_flag
MSHADOW_SGL_DBL_TYPE_SWITCH(type_flag, DType, {
MKL_OP::Vectorize(input_size, inputs[0].dptr<DType>(), outputs[0].dptr<DType>());
});
} else {
Compute<cpu, OP>(attrs, ctx, inputs, req, outputs);
}
}

template<typename OP, typename MKL_OP>
static void MKL_ComputeEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand Down Expand Up @@ -375,28 +396,6 @@ class UnaryOp : public OpBase {
}
}

#if MSHADOW_USE_MKL == 1
template<typename OP, typename MKL_OP>
static void MKL_Compute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (req[0] == kNullOp) return;
auto type_flag = inputs[0].type_flag_;
size_t input_size = inputs[0].Size();
if ((req[0] == kWriteTo || req[0] == kWriteInplace) &&
mkl_func::check_size(input_size) &&
mkl_func::check_type(type_flag)) {
// set DType as float or double according to type_flag
MSHADOW_SGL_DBL_TYPE_SWITCH(type_flag, DType, {
MKL_OP::Vectorize(input_size, inputs[0].dptr<DType>(), outputs[0].dptr<DType>());
});
} else {
Compute<cpu, OP>(attrs, ctx, inputs, req, outputs);
}
}
#endif // MSHADOW_USE_MKL == 1
};

/*! \brief Map legacy unary_bwd to backward_grad */
Expand Down

0 comments on commit 06c51e9

Please sign in to comment.