diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc index 2be6b2baca63..42ef8ff15efb 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc @@ -23,7 +23,7 @@ * \brief */ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include "../quantized_elemwise_add-inl.h" #include "../../nn/mkldnn/mkldnn_ops-inl.h" #include "../../nn/mkldnn/mkldnn_base-inl.h" @@ -73,17 +73,17 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, cons // output default set as int32 float output_data_range = kInt32Range; - auto output_data_type = mkldnn::memory::s32; + auto output_data_type = mkldnn::memory::data_type::s32; // dataA && dataB are uint8 if (out_data[quantized_elemwise_add_enum::kOut].dtype() == mshadow::kInt8) { output_data_range = kInt8Range; - output_data_type = mkldnn::memory::s8; + output_data_type = mkldnn::memory::data_type::s8; } else if (out_data[quantized_elemwise_add_enum::kOut].dtype() == mshadow::kUint8) { output_data_range = kUint8Range; - output_data_type = mkldnn::memory::u8; + output_data_type = mkldnn::memory::data_type::u8; } else { output_data_range = kInt32Range; - output_data_type = mkldnn::memory::s32; + output_data_type = mkldnn::memory::data_type::s32; } float output_min = 0; @@ -100,12 +100,13 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, cons // 2: scale 0 for dataA, scale 1 for data B const int scales_num = 2; std::vector scales(scales_num, 1); + auto engine = CpuEngine::Get()->get_engine(); if (in_data[quantized_elemwise_add_enum::kDataA].dtype() != in_data[quantized_elemwise_add_enum::kDataB].dtype()) { - auto s8_pd = (is_dataA_int8 == true) - ? dataA_mem->get_primitive_desc() - : dataB_mem->get_primitive_desc(); - rescaled_mem = TmpMemMgr::Get()->Alloc(s8_pd); + auto s8_desc = (is_dataA_int8 == true) + ? dataA_mem->get_desc() + : dataB_mem->get_desc(); + rescaled_mem = TmpMemMgr::Get()->Alloc(s8_desc); float u8_reorder_scale = 0; if (params.max_calib_range.has_value() && params.min_calib_range.has_value()) { if (is_dataA_int8 == true) { @@ -130,14 +131,16 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, cons } } std::vector reorder_scale = {u8_reorder_scale}; - primitive_attr reorder_attr; - reorder_attr.set_int_output_round_mode(round_mode::round_nearest); + mkldnn::primitive_attr reorder_attr; reorder_attr.set_output_scales(0, reorder_scale); auto u8_mem = (is_dataA_int8 == true) ? dataB_mem : dataA_mem; - const auto reorder_pd = mkldnn::reorder::primitive_desc(u8_mem->get_primitive_desc(), - s8_pd, + const auto reorder_pd = mkldnn::reorder::primitive_desc(engine, + u8_mem->get_desc(), + engine, + s8_desc, reorder_attr); - MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *u8_mem, *rescaled_mem)); + mkldnn_args_map_t args({{MKLDNN_ARG_FROM, *u8_mem }, {MKLDNN_ARG_TO, *rescaled_mem}}); + MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(reorder_pd), args); if (is_dataA_int8 == true) { dataB_mem = rescaled_mem; @@ -155,27 +158,26 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, cons } } - std::vector in_prims; - std::vector in_pds; - in_prims.push_back(*dataA_mem); - in_prims.push_back(*dataB_mem); - in_pds.push_back(dataA_mem->get_primitive_desc()); - in_pds.push_back(dataB_mem->get_primitive_desc()); + std::vector in_desc; + in_desc.push_back(dataA_mem->get_desc()); + in_desc.push_back(dataB_mem->get_desc()); size_t i_ndim = in_data[quantized_elemwise_add_enum::kDataA].shape().ndim(); mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim); for (size_t i = 0; i < i_ndim; i++) { i_dims[i] = static_cast(in_data[quantized_elemwise_add_enum::kDataA].shape()[i]); } - mkldnn::memory::format i_fmt = static_cast( - in_pds[quantized_elemwise_add_enum::kDataA].desc().data.format); - auto output_desc = mkldnn::memory::desc(i_dims, output_data_type, i_fmt); - mkldnn::sum::primitive_desc pdesc(output_desc, scales, in_pds); + auto output_desc = dataA_mem->get_desc(); + output_desc.data.data_type = static_cast(output_data_type); + mkldnn::sum::primitive_desc pdesc(output_desc, scales, in_desc, engine); auto mem = CreateMKLDNNMem(out_data[quantized_elemwise_add_enum::kOut], - pdesc.dst_primitive_desc(), + pdesc.dst_desc(), req[0], &in_data[0]); + mkldnn_args_map_t args({{MKLDNN_ARG_MULTIPLE_SRC, *dataA_mem}, + {MKLDNN_ARG_MULTIPLE_SRC + 1, *dataB_mem}, + {MKLDNN_ARG_DST, *mem.second}}); MKLDNNStream *stream = MKLDNNStream::Get(); - stream->RegisterPrim(mkldnn::sum(pdesc, in_prims, *mem.second)); + stream->RegisterPrimArgs(mkldnn::sum(pdesc), args); CommitOutput(out_data[quantized_elemwise_add_enum::kOut], mem); stream->Submit(); @@ -203,4 +205,4 @@ NNVM_REGISTER_OP(_contrib_quantized_elemwise_add) } // namespace op } // namespace mxnet -#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_USE_MKLDNN == 100