diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc index 2a4c6d612e65..88a57071fa0e 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc @@ -23,7 +23,7 @@ * \brief */ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include "../../nn/mkldnn/mkldnn_concat-inl.h" #include "../quantization_utils.h" @@ -60,7 +60,7 @@ static void MKLDNNQuantizedConcatForward(const nnvm::NodeAttrs& attrs, const OpC out_data[quantized_concat_enum::kMin].data().dptr()[0] = output_neg_min; out_data[quantized_concat_enum::kMax].data().dptr()[0] = output_pos_max; auto out_scale = GetScale(out_data[quantized_concat_enum::kOut], output_neg_min, output_pos_max); - std::vector data_md; + std::vector data_md; std::vector data_mem; // new_data_mem is for auto-free new created mkldnn memory std::vector> new_data_mem; @@ -71,36 +71,37 @@ static void MKLDNNQuantizedConcatForward(const nnvm::NodeAttrs& attrs, const OpC CHECK(in_data[i].dtype() == out_dtype); auto mem = in_data[i].GetMKLDNNData(); data_mem.push_back(mem); - data_md.push_back(mem->get_primitive_desc()); + data_md.push_back(mem->get_desc()); } else { auto mem = in_data[i].GetMKLDNNData(); - auto pd = mem->get_primitive_desc(); + auto mem_desc = mem->get_desc(); if (in_data[i].dtype() != out_dtype) { - auto mem_desc = pd.desc(); - mkldnn::memory::desc new_md( - mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims), - get_mkldnn_type(out_dtype), static_cast(mem_desc.data.format)); - pd = mkldnn::memory::primitive_desc(new_md, CpuEngine::Get()->get_engine()); + mem_desc.data.data_type = static_cast(get_mkldnn_type(out_dtype)); } - const auto rescaled_mem = std::make_shared(pd); + const auto rescaled_mem = + std::make_shared(mem_desc, CpuEngine::Get()->get_engine()); new_data_mem.push_back(rescaled_mem); std::vector reorder_scale = {out_scale / i_scale}; - primitive_attr reorder_attr; - reorder_attr.set_int_output_round_mode(round_mode::round_nearest); + mkldnn::primitive_attr reorder_attr; reorder_attr.set_output_scales(0, reorder_scale); - const auto reorder_pd = - mkldnn::reorder::primitive_desc(mem->get_primitive_desc(), pd, reorder_attr); - MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *mem, *rescaled_mem)); + const auto reorder_pd = mkldnn::reorder::primitive_desc(*mem, *rescaled_mem, reorder_attr); + mkldnn_args_map_t reorder_args; + reorder_args[MKLDNN_ARG_SRC] = *mem; + reorder_args[MKLDNN_ARG_DST] = *rescaled_mem; + MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(reorder_pd), reorder_args); data_mem.push_back(rescaled_mem.get()); - data_md.push_back(pd); + data_md.push_back(mem_desc); } } MKLDNNConcatFwd& fwd = GetConcatForward(param_.dim, in_data, data_md); - mxnet::mkldnn_output_t out_mem = - CreateMKLDNNMem(out_data[quantized_concat_enum::kOut], fwd.fwd_pd.dst_primitive_desc(), - req[concat_enum::kOut]); - fwd.SetNewMem(data_mem, *out_mem.second); - MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data[quantized_concat_enum::kOut], + fwd.fwd_pd.dst_desc(), req[concat_enum::kOut]); + mkldnn_args_map_t net_args; + net_args[MKLDNN_ARG_DST] = *out_mem.second; + for (int i = 0; i < param_.num_args; i++) { + net_args[MKLDNN_ARG_MULTIPLE_SRC + i] = *data_mem[i]; + } + MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args); CommitOutput(out_data[concat_enum::kOut], out_mem); MKLDNNStream::Get()->Submit(); } @@ -126,4 +127,4 @@ NNVM_REGISTER_OP(_contrib_quantized_concat) } // namespace op } // namespace mxnet -#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_USE_MKLDNN == 100