Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] Improve activation backward (#17973) #18112

Open
wants to merge 1 commit into
base: v1.x
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions src/operator/nn/mkldnn/mkldnn_act.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,21 +249,28 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx
auto input_mem = in_buffer.GetMKLDNNData();
// We need to make sure the two inputs to eltwise_backward has the same memory
// descriptor. Otherwise, the perf will suffer.
if (input_mem->get_desc() != diff_dst_memory->get_desc())
if (input_mem->get_desc() != diff_dst_memory->get_desc()) {
input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_desc());
MKLDNNActBackward &bwd =
GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem);
}

MKLDNNActBackward &bwd = GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem);
MKLDNNStream *stream = MKLDNNStream::Get();
mkldnn_output_t diff_src_memory =
CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req[0]);
mkldnn_args_map_t args = {
{ MKLDNN_ARG_SRC, *input_mem },
{ MKLDNN_ARG_DIFF_DST, *diff_dst_memory },
{ MKLDNN_ARG_DIFF_SRC, *diff_src_memory.second },
};
stream->RegisterPrimArgs(bwd.GetBwd(), args);
CommitOutput(in_grad, diff_src_memory);
stream->Submit();
mkldnn_args_map_t args = {{MKLDNN_ARG_SRC, *input_mem},
{MKLDNN_ARG_DIFF_DST, *diff_dst_memory}};
if (req[0] != kAddTo) {
// req[0] is kWriteTo or kWriteInplace
auto diff_src_memory =
const_cast<NDArray &>(in_grad).CreateMKLDNNData(bwd.bwd_pd.diff_src_desc());
args.insert({MKLDNN_ARG_DIFF_SRC, *diff_src_memory});
stream->RegisterPrimArgs(bwd.GetBwd(), args);
stream->Submit();
} else {
auto diff_src_memory = CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req[0]);
args.insert({MKLDNN_ARG_DIFF_SRC, *diff_src_memory.second});
stream->RegisterPrimArgs(bwd.GetBwd(), args);
CommitOutput(in_grad, diff_src_memory);
stream->Submit();
}
}

void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs,
Expand Down