Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[mkldnn-1.0] upgrade int8 concat to MKLDNN1.0 (#16466)
Browse files Browse the repository at this point in the history
* [mkldnn-1.0] upgrade int8 concat to MKLDNN1.0

* fix lint

* use mkldnn_args_map_t

* update dict usage style

* retrigger CI

* retrigger CI again

* retrigger CI again 2
  • Loading branch information
ElaineBao authored and pengzhao-intel committed Oct 15, 2019
1 parent 43e35a9 commit 4d9a53e
Showing 1 changed file with 23 additions and 22 deletions.
45 changes: 23 additions & 22 deletions src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
* \brief
*/

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include "../../nn/mkldnn/mkldnn_concat-inl.h"
#include "../quantization_utils.h"

Expand Down Expand Up @@ -60,7 +60,7 @@ static void MKLDNNQuantizedConcatForward(const nnvm::NodeAttrs& attrs, const OpC
out_data[quantized_concat_enum::kMin].data().dptr<float>()[0] = output_neg_min;
out_data[quantized_concat_enum::kMax].data().dptr<float>()[0] = output_pos_max;
auto out_scale = GetScale(out_data[quantized_concat_enum::kOut], output_neg_min, output_pos_max);
std::vector<mkldnn::memory::primitive_desc> data_md;
std::vector<mkldnn::memory::desc> data_md;
std::vector<const mkldnn::memory*> data_mem;
// new_data_mem owns the newly created mkldnn memory so that it is freed automatically
std::vector<std::shared_ptr<mkldnn::memory>> new_data_mem;
Expand All @@ -71,36 +71,37 @@ static void MKLDNNQuantizedConcatForward(const nnvm::NodeAttrs& attrs, const OpC
CHECK(in_data[i].dtype() == out_dtype);
auto mem = in_data[i].GetMKLDNNData();
data_mem.push_back(mem);
data_md.push_back(mem->get_primitive_desc());
data_md.push_back(mem->get_desc());
} else {
auto mem = in_data[i].GetMKLDNNData();
auto pd = mem->get_primitive_desc();
auto mem_desc = mem->get_desc();
if (in_data[i].dtype() != out_dtype) {
auto mem_desc = pd.desc();
mkldnn::memory::desc new_md(
mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims),
get_mkldnn_type(out_dtype), static_cast<mkldnn::memory::format>(mem_desc.data.format));
pd = mkldnn::memory::primitive_desc(new_md, CpuEngine::Get()->get_engine());
mem_desc.data.data_type = static_cast<mkldnn_data_type_t>(get_mkldnn_type(out_dtype));
}
const auto rescaled_mem = std::make_shared<mkldnn::memory>(pd);
const auto rescaled_mem =
std::make_shared<mkldnn::memory>(mem_desc, CpuEngine::Get()->get_engine());
new_data_mem.push_back(rescaled_mem);
std::vector<float> reorder_scale = {out_scale / i_scale};
primitive_attr reorder_attr;
reorder_attr.set_int_output_round_mode(round_mode::round_nearest);
mkldnn::primitive_attr reorder_attr;
reorder_attr.set_output_scales(0, reorder_scale);
const auto reorder_pd =
mkldnn::reorder::primitive_desc(mem->get_primitive_desc(), pd, reorder_attr);
MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *mem, *rescaled_mem));
const auto reorder_pd = mkldnn::reorder::primitive_desc(*mem, *rescaled_mem, reorder_attr);
mkldnn_args_map_t reorder_args;
reorder_args[MKLDNN_ARG_SRC] = *mem;
reorder_args[MKLDNN_ARG_DST] = *rescaled_mem;
MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(reorder_pd), reorder_args);
data_mem.push_back(rescaled_mem.get());
data_md.push_back(pd);
data_md.push_back(mem_desc);
}
}
MKLDNNConcatFwd& fwd = GetConcatForward(param_.dim, in_data, data_md);
mxnet::mkldnn_output_t out_mem =
CreateMKLDNNMem(out_data[quantized_concat_enum::kOut], fwd.fwd_pd.dst_primitive_desc(),
req[concat_enum::kOut]);
fwd.SetNewMem(data_mem, *out_mem.second);
MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data[quantized_concat_enum::kOut],
fwd.fwd_pd.dst_desc(), req[concat_enum::kOut]);
mkldnn_args_map_t net_args;
net_args[MKLDNN_ARG_DST] = *out_mem.second;
for (int i = 0; i < param_.num_args; i++) {
net_args[MKLDNN_ARG_MULTIPLE_SRC + i] = *data_mem[i];
}
MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
CommitOutput(out_data[concat_enum::kOut], out_mem);
MKLDNNStream::Get()->Submit();
}
Expand All @@ -126,4 +127,4 @@ NNVM_REGISTER_OP(_contrib_quantized_concat)
} // namespace op
} // namespace mxnet

#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_USE_MKLDNN == 100

0 comments on commit 4d9a53e

Please sign in to comment.