Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[mkldnn-v1.0] Add MKL-DNN sum concat #16263

Merged
merged 1 commit into from
Oct 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions src/operator/nn/mkldnn/mkldnn_concat-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
/*!
* \file mkldnn_concat-inl.h
* \brief
* \author Wenting Jiang
* \author
*/
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONCAT_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONCAT_INL_H_


#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include <vector>
#include <utility>
#include "../concat-inl.h"
Expand All @@ -40,25 +40,20 @@ class MKLDNNConcatFwd {
public:
mkldnn::concat::primitive_desc fwd_pd;

MKLDNNConcatFwd(int concat_dim, const std::vector<mkldnn::memory::primitive_desc> &data_md)
: fwd_pd(concat_dim, data_md) {
data.resize(data_md.size());
MKLDNNConcatFwd(int concat_dim, const std::vector<mkldnn::memory::desc> &data_md)
: fwd_pd(concat_dim, data_md, CpuEngine::Get()->get_engine()) {
fwd_ = std::make_shared<mkldnn::concat>(fwd_pd);
}

void SetNewMem(const std::vector<const mkldnn::memory *> &in_data, const mkldnn::memory &output);

const mkldnn::concat &GetFwd() const;

private:
std::shared_ptr<mkldnn::concat> fwd;
std::vector<std::shared_ptr<mkldnn::memory>> data;
std::vector<mkldnn::primitive::at> data_mem;
std::shared_ptr<mkldnn::memory> out;
std::shared_ptr<mkldnn::concat> fwd_;
};

static MKLDNNConcatFwd &GetConcatForward(
int concat_dim, const std::vector<NDArray> &in_data,
const std::vector<mkldnn::memory::primitive_desc> &data_md) {
const std::vector<mkldnn::memory::desc> &data_md) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<OpSignature, MKLDNNConcatFwd, OpHash> fwds;
#else
Expand All @@ -79,5 +74,5 @@ static MKLDNNConcatFwd &GetConcatForward(
} // namespace op
} // namespace mxnet

#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_USE_MKLDNN == 100
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONCAT_INL_H_
82 changes: 33 additions & 49 deletions src/operator/nn/mkldnn/mkldnn_concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,62 +20,45 @@
/*!
* \file mkldnn_concat.cc
* \brief
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: A short brief may be preferred 😄

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ciyongch Can you help add a brief description?

Copy link
Contributor Author

@rongzha1 rongzha1 Oct 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems this is just the standard file-header style; none of the other MKL-DNN files add brief info here.

* \author Wenting Jiang
* \author
*/

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include "mkldnn_concat-inl.h"

namespace mxnet {
namespace op {

void MKLDNNConcatFwd::SetNewMem(const std::vector<const mkldnn::memory *> &in_data,
const mkldnn::memory &output) {
CHECK_EQ(in_data.size(), data.size());
for (size_t i = 0; i < data.size(); i++) {
if (this->data[i] == nullptr) {
this->data[i] = std::shared_ptr<mkldnn::memory>(
new mkldnn::memory(in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle()));
this->data_mem.push_back(*this->data[i]);
} else {
this->data[i]->set_data_handle(in_data[i]->get_data_handle());
}
}
if (this->out == nullptr)
this->out = std::shared_ptr<mkldnn::memory>(
new mkldnn::memory(fwd_pd.dst_primitive_desc(), output.get_data_handle()));
else
this->out->set_data_handle(output.get_data_handle());

if (this->fwd == nullptr) fwd.reset(new mkldnn::concat(fwd_pd, data_mem, *out));
}

const mkldnn::concat &MKLDNNConcatFwd::GetFwd() const { return *fwd; }
const mkldnn::concat &MKLDNNConcatFwd::GetFwd() const { return *fwd_; }

void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data) {
TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]);
const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
int num_in_data = param.num_args;
int concat_dim = param.dim;
std::vector<mkldnn::memory::primitive_desc> data_md;
const int num_in_data = param.num_args;
const int concat_dim = param.dim;
std::vector<mkldnn::memory::desc> data_md;
std::vector<const mkldnn::memory *> data_mem;
data_md.reserve(num_in_data);
data_mem.reserve(num_in_data);
for (int i = 0; i < num_in_data; i++) {
const mkldnn::memory *tmp_mem = in_data[i].GetMKLDNNData();
mkldnn::memory::primitive_desc tmp_pd = tmp_mem->get_primitive_desc();
data_md.push_back(tmp_pd);
mkldnn::memory::desc tmp_md = tmp_mem->get_desc();
data_md.push_back(tmp_md);
data_mem.push_back(tmp_mem);
}
MKLDNNConcatFwd &fwd = GetConcatForward(concat_dim, in_data, data_md);
mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data[concat_enum::kOut],
fwd.fwd_pd.dst_primitive_desc(),
fwd.fwd_pd.dst_desc(),
req[concat_enum::kOut]);
fwd.SetNewMem(data_mem, *out_mem.second);
MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
std::unordered_map<int, mkldnn::memory> net_args;
net_args.insert({MKLDNN_ARG_DST, *out_mem.second});
for (int i = 0; i < num_in_data; i++) {
net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, *data_mem[i]});
}
MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
CommitOutput(out_data[concat_enum::kOut], out_mem);
MKLDNNStream::Get()->Submit();
}
Expand All @@ -86,11 +69,9 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray>& outputs) {
TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]);
const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
int num_in_data = param.num_args;
int axis_ = param.dim;
auto engine = CpuEngine::Get()->get_engine();
auto gz_mem = inputs[0].GetMKLDNNData();
mkldnn::memory::primitive_desc gz_pd = gz_mem->get_primitive_desc();
const int num_in_data = param.num_args;
const int axis = param.dim;
const auto gradz_mem = inputs[0].GetMKLDNNData();
/* init the offset */
mkldnn::memory::dims offsets(outputs[0].shape().ndim());
for (auto &v : offsets) {
Expand All @@ -99,19 +80,22 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,

for (int i = 0; i < num_in_data; i++) {
mkldnn::memory::dims diff_src_tz(outputs[i].shape().begin(), outputs[i].shape().end());
auto diff_src_mpd = outputs[i].GetMKLDNNData()->get_primitive_desc();
auto gradi_mem_ = CreateMKLDNNMem(outputs[i], diff_src_mpd, req[i]);
// create view from gy to gxs[i]
std::shared_ptr<mkldnn::view::primitive_desc> view_pd;
view_pd.reset(new mkldnn::view::primitive_desc(gz_pd, diff_src_tz, offsets));
// create reorder primitive from gy to gxs[i]
mkldnn::reorder::primitive_desc reorder_pd(
view_pd.get()->dst_primitive_desc(), diff_src_mpd);
offsets[axis_] += diff_src_tz[axis_];
MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(
reorder_pd, *gz_mem, *gradi_mem_.second));
CommitOutput(outputs[i], gradi_mem_);
auto diff_src_md = outputs[i].GetMKLDNNData()->get_desc();
auto gradi_mem = CreateMKLDNNMem(outputs[i], diff_src_md, req[i]);

auto from_md = gradz_mem->get_desc().submemory_desc(diff_src_tz, offsets);
auto from_mem = new mkldnn::memory(from_md, gradz_mem->get_engine(),
gradz_mem->get_data_handle());
offsets[axis] += diff_src_tz[axis];

std::unordered_map<int, mkldnn::memory> net_args({
{MKLDNN_ARG_FROM, *gradz_mem},
{MKLDNN_ARG_TO, *gradi_mem.second}
});
MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(*from_mem, *gradi_mem.second), net_args);
CommitOutput(outputs[i], gradi_mem);
}

MKLDNNStream::Get()->Submit();
}

Expand Down
11 changes: 5 additions & 6 deletions src/operator/nn/mkldnn/mkldnn_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,15 @@
*/

/*!
* \file mkldnn_softmax.cc
* \file mkldnn_copy.cc
* \brief
* \author Da Zheng
* \author
*/

#include "../softmax-inl.h"
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
namespace mxnet {
namespace op {

Expand All @@ -47,9 +46,9 @@ void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
// We should try and force the input memory has the same format
// as the input output. If not, we'll have to reorder memory.
auto out_mem = out_data.GetMKLDNNData();
in_mem = data.GetMKLDNNData(out_mem ->get_primitive_desc());
in_mem = data.GetMKLDNNData(out_mem ->get_desc());
if (in_mem == nullptr)
in_mem = data.GetMKLDNNDataReorder(out_mem->get_primitive_desc());
in_mem = data.GetMKLDNNDataReorder(out_mem->get_desc());
MKLDNNSum(*out_mem, *in_mem, *out_mem);
} else {
const_cast<NDArray &>(out_data).CopyFrom(*in_mem);
Expand Down
48 changes: 30 additions & 18 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,17 @@ namespace mxnet {
namespace op {

#if MXNET_USE_MKLDNN == 1
/* For sum */
void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &inputs, const OpReqType &req,
const NDArray &out_data);

/* For copy */
void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);
void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const NDArray &input,
const OpReqType &req,
const NDArray &output);

/* For concat */
void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data);
void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);
void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &input,
const OpReqType &req,
const NDArray &output);
#endif

#if MXNET_USE_MKLDNN == 100
Expand Down Expand Up @@ -122,6 +114,26 @@ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &c
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data);

/* For sum */
void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &inputs, const OpReqType &req,
const NDArray &out_data);

/* For copy */
void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);

/* For concat */
void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data);
void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);

void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
const mkldnn::memory &out);

Expand Down
61 changes: 17 additions & 44 deletions src/operator/nn/mkldnn/mkldnn_sum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,41 +54,33 @@ void MKLDNNSum(const mkldnn::memory &arr1,
in_mem2 = tmp_memory2;
}
mkldnn::sum::primitive_desc sum_pd(output_pd, scales, input_pds, CpuEngine::Get()->get_engine());
std::unordered_map<int, mkldnn::memory> args = {
mkldnn_args_map_t args = {
{ MKLDNN_ARG_MULTIPLE_SRC, *in_mem1 },
{ MKLDNN_ARG_MULTIPLE_SRC + 1, *in_mem2 },
{ MKLDNN_ARG_DST, out },
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also needs to avoid temporary pairs here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is less efficient than several separate assignments.

MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::sum(sum_pd), args);
}

#endif

#if MXNET_USE_MKLDNN == 1
class MKLDNNSumFwd {
public:
mkldnn::sum::primitive_desc fwd_pd;

MKLDNNSumFwd(const std::vector<float> &scales,
const std::vector<mkldnn::memory::primitive_desc> &data_md)
: fwd_pd(scales, data_md) {
data_.resize(data_md.size());
const std::vector<mkldnn::memory::desc> &data_md)
: fwd_pd(scales, data_md, CpuEngine::Get()->get_engine()) {
fwd_ = std::make_shared<mkldnn::sum>(fwd_pd);
}

void SetNewMem(const std::vector<const mkldnn::memory *> &in_data, const mkldnn::memory &output);

const mkldnn::sum &GetFwd() const { return *fwd_; }

private:
std::shared_ptr<mkldnn::sum> fwd_;
std::vector<std::shared_ptr<mkldnn::memory>> data_;
std::vector<mkldnn::primitive::at> data_mem_;
std::shared_ptr<mkldnn::memory> out_;
};

static MKLDNNSumFwd &GetSumForward(
const std::vector<float> &scales, const std::vector<NDArray> &in_data,
const std::vector<mkldnn::memory::primitive_desc> &data_md) {
const std::vector<mkldnn::memory::desc> &data_md) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<OpSignature, MKLDNNSumFwd, OpHash> fwds;
#else
Expand All @@ -105,43 +97,20 @@ static MKLDNNSumFwd &GetSumForward(
return it->second;
}

void MKLDNNSumFwd::SetNewMem(const std::vector<const mkldnn::memory *> &in_data,
const mkldnn::memory &output) {
auto num_inputs = data_.size();
CHECK_EQ(in_data.size(), num_inputs);
for (index_t i = 0; i < static_cast<index_t>(num_inputs); ++i) {
if (this->data_[i] == nullptr) {
this->data_[i] = std::shared_ptr<mkldnn::memory>(
new mkldnn::memory(in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle()));
this->data_mem_.push_back(*this->data_[i]);
} else {
this->data_[i]->set_data_handle(in_data[i]->get_data_handle());
}
}
if (this->out_ == nullptr)
this->out_ = std::shared_ptr<mkldnn::memory>(
new mkldnn::memory(fwd_pd.dst_primitive_desc(), output.get_data_handle()));
else
this->out_->set_data_handle(output.get_data_handle());

if (this->fwd_ == nullptr)
this->fwd_.reset(new mkldnn::sum(fwd_pd, this->data_mem_, *this->out_));
}

void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &inputs, const OpReqType &req,
const NDArray &out_data) {
TmpMemMgr::Get()->Init(ctx.requested[0]);
auto num_inputs = inputs.size();
std::vector<mkldnn::memory::primitive_desc> data_md;
const int num_inputs = inputs.size();
std::vector<mkldnn::memory::desc> data_md;
std::vector<const mkldnn::memory *> data_mem;
std::vector<float> scales(num_inputs, 1);
std::vector<NDArray> in_bufs(num_inputs);

data_md.reserve(num_inputs);
data_mem.reserve(num_inputs);

for (index_t i = 0; i < static_cast<index_t>(num_inputs); ++i) {
for (int i = 0; i < num_inputs; ++i) {
const mkldnn::memory *in_mem;
if (inputs[i].IsMKLDNNData() && inputs[i].IsView()) {
in_bufs[i] = inputs[i].Reorder2Default();
Expand All @@ -150,18 +119,22 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
in_bufs[i] = inputs[i];
in_mem = inputs[i].GetMKLDNNData();
}
mkldnn::memory::primitive_desc tmp_pd = in_mem->get_primitive_desc();
data_md.push_back(tmp_pd);
mkldnn::memory::desc tmp_md = in_mem->get_desc();
data_md.push_back(tmp_md);
data_mem.push_back(in_mem);
}

MKLDNNSumFwd &fwd = GetSumForward(scales, in_bufs, data_md);
mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data,
fwd.fwd_pd.dst_primitive_desc(),
fwd.fwd_pd.dst_desc(),
req,
&in_bufs[0]);
fwd.SetNewMem(data_mem, *out_mem.second);
MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
mkldnn_args_map_t net_args;
net_args.insert({MKLDNN_ARG_DST, *out_mem.second});
for (int i = 0; i < num_inputs; ++i) {
net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, *data_mem[i]});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ZhennanQin has recommended a better insert approach, see #16199 (comment).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK thanks

}
MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
CommitOutput(out_data, out_mem);
MKLDNNStream::Get()->Submit();
}
Expand Down
Loading