
Commit

[MKL-DNN] Integrate Conv3d and Pool3d/1d (#17884)
* Integrate MKL-DNN conv3d and pool3d/1d

* fix UT & address comments

* clean code

* rebase against latest master
wuxun-zhang committed Apr 15, 2020
1 parent 6d8b679 commit 57dc78d
Showing 15 changed files with 492 additions and 279 deletions.
8 changes: 4 additions & 4 deletions src/operator/nn/mkldnn/mkldnn_act.cc
@@ -48,9 +48,9 @@ bool SupportMKLDNNAct(const ActivationParam& param) {
}

bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input) {
// MKL-DNN Activation supports 1d, 2d, 3d, 4d data layout
// MKL-DNN Activation supports 1d, 2d, 3d, 4d and 5d data layout
if ((input.shape().ndim() < 1) ||
(input.shape().ndim() > 4) ||
(input.shape().ndim() > 5) ||
!(input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16))
return false;
return SupportMKLDNNAct(param);
@@ -63,9 +63,9 @@ bool SupportMKLDNNLeakyRelu(const LeakyReLUParam& param) {
}

bool SupportMKLDNNLeakyRelu(const LeakyReLUParam& param, const NDArray &input) {
// MKL-DNN Activation supports 1d, 2d, 3d, 4d data layout
// MKL-DNN Activation supports 1d, 2d, 3d, 4d and 5d data layout
if ((input.shape().ndim() < 1) ||
(input.shape().ndim() > 4) ||
(input.shape().ndim() > 5) ||
!(input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16))
return false;
return SupportMKLDNNLeakyRelu(param);
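
The two guards above only widen the accepted tensor rank from 1d-4d to 1d-5d while keeping the float32/bfloat16 restriction. A minimal sketch of the same predicate with the shape and dtype passed in directly rather than through an NDArray (the helper name and the dtype codes are illustrative placeholders, not part of this commit):

#include <cstdint>
#include <vector>

// Illustrative stand-ins for mshadow::kFloat32 / mshadow::kBfloat16.
constexpr int kFloat32Code  = 0;
constexpr int kBfloat16Code = 1;

// True when MKL-DNN activation can handle this input: rank 1..5 and a
// supported element type, mirroring the updated checks in mkldnn_act.cc.
bool SupportsMKLDNNActInput(const std::vector<int64_t>& shape, int dtype) {
  const int ndim = static_cast<int>(shape.size());
  if (ndim < 1 || ndim > 5) return false;
  return dtype == kFloat32Code || dtype == kBfloat16Code;
}
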
41 changes: 26 additions & 15 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -153,9 +153,8 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
// MKLDNN currently does not support 0-dim Tensor and 0-size Tensor
return false;
}

return (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16) &&
(ndim == 1 || ndim == 2 || ndim == 4);
(ndim == 1 || ndim == 2 || ndim == 4);
}

static inline bool SupportMKLDNNRnn(const NDArray &input) {
@@ -332,20 +331,32 @@ inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr,
if (num_groups == 1) {
return GetMemDesc(arr, dtype);
} else {
auto ndim = arr.shape().ndim();
CHECK((ndim == 3) || (ndim == 4))
<< "MKL-DNN weight currectly supports 3d and 4d layout";
const auto ndim = arr.shape().ndim();
CHECK((ndim == 3) || (ndim == 4) || (ndim == 5))
<< "MKL-DNN weight currently supports 3d or 4d or 5d layout";
auto tz = mkldnn::memory::dims{0};
const int N = 0, H = 2, W = 3, C = 1;
if (ndim == 3) {
tz = mkldnn::memory::dims{
num_groups, static_cast<int>(arr.shape()[N] / num_groups),
static_cast<int>(arr.shape()[C]), static_cast<int>(arr.shape()[H])};
} else {
tz = mkldnn::memory::dims{
num_groups, static_cast<int>(arr.shape()[N] / num_groups),
static_cast<int>(arr.shape()[C]), static_cast<int>(arr.shape()[H]),
static_cast<int>(arr.shape()[W])};
int N = 0, C = 1, H = 2, W = 3;
int D = -1;
if (ndim == 5) {
D = 2;
H = 3;
W = 4;
}
switch (ndim) {
case 3:
tz = mkldnn::memory::dims{
num_groups, arr.shape()[N] / num_groups,
arr.shape()[C], arr.shape()[H]};
break;
case 4:
tz = mkldnn::memory::dims{
num_groups, arr.shape()[N] / num_groups,
arr.shape()[C], arr.shape()[H], arr.shape()[W]};
break;
case 5:
tz = mkldnn::memory::dims{
num_groups, arr.shape()[N] / num_groups,
arr.shape()[C], arr.shape()[D], arr.shape()[H], arr.shape()[W]};
}
return mkldnn::memory::desc{tz, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any};
}
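
For grouped weights, GetWeightDesc prepends the group count and divides the output channels by it, so a 5d weight of shape (O, I, D, H, W) with G groups becomes the 6d MKL-DNN dims (G, O/G, I, D, H, W); the 3d and 4d cases follow the same pattern. A small sketch of that mapping written against plain vectors instead of NDArray/TShape (the function name is illustrative):

#include <cassert>
#include <cstdint>
#include <vector>

// Convert a 3d/4d/5d weight shape into grouped MKL-DNN dims:
// (O, I, ...) -> (G, O/G, I, ...) with the spatial axes copied through.
std::vector<int64_t> GroupedWeightDims(const std::vector<int64_t>& shape,
                                       int num_groups) {
  const size_t ndim = shape.size();
  assert(ndim == 3 || ndim == 4 || ndim == 5);
  assert(shape[0] % num_groups == 0);
  std::vector<int64_t> dims;
  dims.push_back(num_groups);             // G
  dims.push_back(shape[0] / num_groups);  // O / G
  for (size_t i = 1; i < ndim; ++i)       // I, (D,) H, W unchanged
    dims.push_back(shape[i]);
  return dims;
}

// e.g. GroupedWeightDims({64, 3, 3, 3, 3}, 2) yields {2, 32, 3, 3, 3, 3}.
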
47 changes: 30 additions & 17 deletions src/operator/nn/mkldnn/mkldnn_base.cc
@@ -240,31 +240,44 @@ const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) {
auto tz = mkldnn::memory::dims{0};
auto format_tag = mkldnn::memory::format_tag::undef;
auto engine = CpuEngine::Get()->get_engine();
const int O = 0, I = 1, H = 2, W = 3;
if (arr.shape().ndim() == 2) {
tz = mkldnn::memory::dims{static_cast<int>(arr.shape()[O]), static_cast<int>(arr.shape()[I])};
const int ndim = arr.shape().ndim();
int O = 0, I = 1, H = 2, W = 3;
int D = -1;
if (ndim == 5) {
D = 2;
H = 3;
W = 4;
}
if (ndim == 2) {
tz = mkldnn::memory::dims{arr.shape()[O], arr.shape()[I]};
format_tag = mkldnn::memory::format_tag::oi;
} else if (arr.shape().ndim() == 3) {
} else if (ndim == 3) {
tz = num_groups > 1
? mkldnn::memory::dims{num_groups, static_cast<int>(arr.shape()[O] / num_groups),
static_cast<int>(arr.shape()[I]),
static_cast<int>(arr.shape()[H])}
: mkldnn::memory::dims{static_cast<int>(arr.shape()[O]),
static_cast<int>(arr.shape()[I]),
static_cast<int>(arr.shape()[H])};
? mkldnn::memory::dims{num_groups, arr.shape()[O] / num_groups,
arr.shape()[I], arr.shape()[H]}
: mkldnn::memory::dims{arr.shape()[O],
arr.shape()[I], arr.shape()[H]};
format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goiw
: mkldnn::memory::format_tag::oiw;
} else if (arr.shape().ndim() == 4) {
} else if (ndim == 4) {
tz = num_groups > 1
? mkldnn::memory::dims{num_groups, static_cast<int>(arr.shape()[O] / num_groups),
static_cast<int>(arr.shape()[I]),
static_cast<int>(arr.shape()[H]),
static_cast<int>(arr.shape()[W])}
? mkldnn::memory::dims{num_groups, arr.shape()[O] / num_groups,
arr.shape()[I], arr.shape()[H],
arr.shape()[W]}
: mkldnn::memory::dims{
static_cast<int>(arr.shape()[O]), static_cast<int>(arr.shape()[I]),
static_cast<int>(arr.shape()[H]), static_cast<int>(arr.shape()[W])};
arr.shape()[O], arr.shape()[I], arr.shape()[H], arr.shape()[W]};
format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goihw
: mkldnn::memory::format_tag::oihw;
} else if (ndim == 5) {
tz = num_groups > 1
? mkldnn::memory::dims{num_groups, arr.shape()[O] / num_groups,
arr.shape()[I], arr.shape()[D],
arr.shape()[H], arr.shape()[W]}
: mkldnn::memory::dims{
arr.shape()[O], arr.shape()[I], arr.shape()[D],
arr.shape()[H], arr.shape()[W]};
format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goidhw
: mkldnn::memory::format_tag::oidhw;
} else {
LOG(FATAL) << "The weight array has an unsupported number of dimensions";
}
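
GetWeights now picks the MKL-DNN format tag purely from the weight rank and the presence of groups: oi for 2d weights, oiw/goiw for 1d convolution weights, oihw/goihw for 2d, and oidhw/goidhw for the new 3d case. A condensed sketch of that selection, assuming the mkldnn.hpp header and the format_tag values already used in this file (the helper name is illustrative):

#include <mkldnn.hpp>

// Map weight rank plus grouping onto an MKL-DNN weight format tag;
// unsupported ranks fall through to format_tag::undef.
mkldnn::memory::format_tag WeightFormatTag(int ndim, bool grouped) {
  using tag = mkldnn::memory::format_tag;
  switch (ndim) {
    case 2: return tag::oi;                              // fully connected
    case 3: return grouped ? tag::goiw   : tag::oiw;     // 1d convolution
    case 4: return grouped ? tag::goihw  : tag::oihw;    // 2d convolution
    case 5: return grouped ? tag::goidhw : tag::oidhw;   // 3d convolution
    default: return tag::undef;
  }
}
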
60 changes: 52 additions & 8 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -37,11 +37,13 @@ DMLC_REGISTER_PARAMETER(MKLDNNConvParam);

bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) {
if ((params.kernel.ndim() != 1) &&
(params.kernel.ndim() != 2))
(params.kernel.ndim() != 2) &&
(params.kernel.ndim() != 3))
return false;
return SupportMKLDNNQuantize(input.dtype()) &&
((input.shape().ndim() == 3) ||
(input.shape().ndim() == 4));
(input.shape().ndim() == 4) ||
(input.shape().ndim() == 5));
}

std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
@@ -77,9 +79,19 @@ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
strides[1] = param.conv_param.stride[1];
padding[0] = param.conv_param.pad[0];
padding[1] = param.conv_param.pad[1];
} else if (param.conv_param.kernel.ndim() == 3) {
CHECK_GE(param.conv_param.stride.ndim(), 3);
CHECK_GE(param.conv_param.pad.ndim(), 3);
CHECK_GE(param.conv_param.dilate.ndim(), 3);
strides[0] = param.conv_param.stride[0];
strides[1] = param.conv_param.stride[1];
strides[2] = param.conv_param.stride[2];
padding[0] = param.conv_param.pad[0];
padding[1] = param.conv_param.pad[1];
padding[2] = param.conv_param.pad[2];
} else {
LOG(FATAL) << "Unexpected MKL-DNN Conv kernel size "
<< param.conv_param.kernel.ndim() << ", supporting only 1 or 2.";
<< param.conv_param.kernel.ndim() << ", supporting only 1 or 2 or 3.";
}
mkldnn::primitive_attr attr;
mkldnn::post_ops ops;
@@ -141,9 +153,13 @@ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
} else if (param.conv_param.dilate.ndim() == 2) {
dilates[0] = param.conv_param.dilate[0] - 1;
dilates[1] = param.conv_param.dilate[1] - 1;
} else if (param.conv_param.dilate.ndim() == 3) {
dilates[0] = param.conv_param.dilate[0] - 1;
dilates[1] = param.conv_param.dilate[1] - 1;
dilates[2] = param.conv_param.dilate[2] - 1;
} else {
LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size " << param.conv_param.dilate.ndim()
<< ", supporting only 1 or 2.";
<< ", supporting only 1 or 2 or 3.";
}
if (bias_md_ptr == nullptr) {
mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md,
@@ -181,9 +197,19 @@ static std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc> GetCon
strides[1] = param.stride[1];
padding[0] = param.pad[0];
padding[1] = param.pad[1];
} else if (param.kernel.ndim() == 3) {
CHECK_GE(param.stride.ndim(), 3);
CHECK_GE(param.pad.ndim(), 3);
CHECK_GE(param.dilate.ndim(), 3);
strides[0] = param.stride[0];
strides[1] = param.stride[1];
strides[2] = param.stride[2];
padding[0] = param.pad[0];
padding[1] = param.pad[1];
padding[2] = param.pad[2];
} else {
LOG(FATAL) << "Unexpected MKL-DNN Conv kernel size " << param.kernel.ndim()
<< ", supporting only 1 or 2.";
<< ", supporting only 1 or 2 or 3.";
}

auto GetConvBwdDataPd = [&data, &weight, &output,
@@ -216,9 +242,13 @@ static std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc> GetCon
} else if (param.dilate.ndim() == 2) {
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
} else if (param.dilate.ndim() == 3) {
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
dilates[2] = param.dilate[2] - 1;
} else {
LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size "
<< param.dilate.ndim() << ", supporting only 1 or 2.";
<< param.dilate.ndim() << ", supporting only 1 or 2 or 3.";
}
mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct, data_md,
weight_md, out_md, strides, dilates, padding,
@@ -250,9 +280,19 @@ static std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc> Get
strides[1] = param.stride[1];
padding[0] = param.pad[0];
padding[1] = param.pad[1];
} else if (param.kernel.ndim() == 3) {
CHECK_GE(param.stride.ndim(), 3);
CHECK_GE(param.pad.ndim(), 3);
CHECK_GE(param.dilate.ndim(), 3);
strides[0] = param.stride[0];
strides[1] = param.stride[1];
strides[2] = param.stride[2];
padding[0] = param.pad[0];
padding[1] = param.pad[1];
padding[2] = param.pad[2];
} else {
LOG(FATAL) << "Unexpected MKL-DNN Conv kernel size " << param.kernel.ndim()
<< ", supporting only 1 or 2.";
<< ", supporting only 1 or 2 or 3.";
}

auto GetConvBwdWeightsPd = [&data, &weight, &output,
@@ -291,9 +331,13 @@ static std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc> Get
} else if (param.dilate.ndim() == 2) {
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
} else if (param.dilate.ndim() == 3) {
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
dilates[2] = param.dilate[2] - 1;
} else {
LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size "
<< param.dilate.ndim() << ", supporting only 1 or 2.";
<< param.dilate.ndim() << ", supporting only 1 or 2 or 3.";
}
if (bias == nullptr) {
mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
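
All three descriptor builders in this file (forward, backward-data, backward-weights) repeat the same expansion: copy param.stride and param.pad element-wise for a 1d/2d/3d kernel, and convert MXNet's dilation (where 1 means "no dilation") to MKL-DNN's zero-based convention by subtracting one from each value. A minimal sketch of that conversion using plain vectors in place of mxnet::TShape (the helper name is illustrative):

#include <cstdint>
#include <vector>

// MXNet stores dilation so that 1 means "no dilation"; MKL-DNN expects 0
// for the same thing, hence the element-wise subtraction seen in the diff.
std::vector<int64_t> ToMKLDNNDilates(const std::vector<int64_t>& mxnet_dilate) {
  std::vector<int64_t> dilates(mxnet_dilate.size());
  for (size_t i = 0; i < mxnet_dilate.size(); ++i)
    dilates[i] = mxnet_dilate[i] - 1;
  return dilates;
}

// e.g. a 3d kernel with dilate = (1, 2, 2) maps to MKL-DNN dilates (0, 1, 1).
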
59 changes: 36 additions & 23 deletions src/operator/nn/mkldnn/mkldnn_pooling-inl.h
@@ -38,17 +38,15 @@ class MKLDNNPoolingFwd {
public:
MKLDNNPoolingFwd(const mxnet::NDArray &input,
const mxnet::NDArray &output,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int padding_t, const int padding_b,
const int padding_l, const int padding_r,
const mkldnn::memory::dims &kernel,
const mkldnn::memory::dims &strides,
const mkldnn::memory::dims &pad_l,
const mkldnn::memory::dims &pad_r,
const mkldnn::algorithm alg_kind,
const bool with_workspace, const bool is_train):
with_workspace_(with_workspace),
fwd_(nullptr) {
Init(input, output,
kernel_h, kernel_w, stride_h, stride_w,
padding_t, padding_b, padding_l, padding_r,
Init(input, output, kernel, strides, pad_l, pad_r,
is_train, alg_kind);
}

@@ -67,10 +65,10 @@ class MKLDNNPoolingFwd {
private:
void Init(const mxnet::NDArray &input,
const mxnet::NDArray &output,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int padding_t, const int padding_b,
const int padding_l, const int padding_r,
const mkldnn::memory::dims &kernel,
const mkldnn::memory::dims &strides,
const mkldnn::memory::dims &pad_l,
const mkldnn::memory::dims &pad_r,
const bool is_train, const mkldnn::algorithm alg_kind);
};
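
With this refactor the constructor and Init no longer take scalar kernel_h/kernel_w, stride_h/stride_w and per-side padding arguments; callers pass mkldnn::memory::dims instead, so the same signature covers 1d, 2d and 3d pooling. A hedged sketch of how a caller might assemble those vectors for a 3d pooling (the values and the function name are illustrative, not taken from the commit):

#include <mkldnn.hpp>

// Dims a MKLDNNPoolingFwd-style constructor would consume for a 3d pooling
// with a 2x2x2 kernel, stride 2 and zero padding; a 1d pooling would use
// one-element vectors in exactly the same positions.
void BuildPooling3dDims() {
  mkldnn::memory::dims kernel  = {2, 2, 2};  // (d, h, w)
  mkldnn::memory::dims strides = {2, 2, 2};
  mkldnn::memory::dims pad_l   = {0, 0, 0};  // leading (front/top/left) padding
  mkldnn::memory::dims pad_r   = {0, 0, 0};  // trailing (back/bottom/right) padding
  (void)kernel; (void)strides; (void)pad_l; (void)pad_r;
}
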

@@ -98,31 +96,46 @@ inline int GetPaddingSizeFull(dim_t x, int padl, int padr, int k, int s) {
}

inline bool SupportMKLDNNPooling(const PoolingParam &param) {
return param.kernel.ndim() == 2 &&
return (param.kernel.ndim() == 1 || param.kernel.ndim() == 2 ||
param.kernel.ndim() == 3) &&
(param.pool_type == pool_enum::kMaxPooling ||
param.pool_type == pool_enum::kAvgPooling) &&
(!param.layout.has_value() || param.layout.value() == mshadow::kNCHW);
(!param.layout.has_value() ||
(param.layout.value() == mshadow::kNCW || param.layout.value() == mshadow::kNCHW ||
param.layout.value() == mshadow::kNCDHW));
}

inline bool SupportMKLDNNPooling(const PoolingParam &param,
const mxnet::TShape &dshape) {
bool ret = SupportMKLDNNPooling(param);
if (!ret)
const NDArray &input) {
const auto dshape = input.shape();
const auto ndim = dshape.ndim();
const auto dtype = input.dtype();

if (!(SupportStorageMKLDNN(input.storage_type()) && (ndim == 3 || ndim == 4 || ndim == 5) &&
(dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16)))
return false;

if (!SupportMKLDNNPooling(param))
return false;

if (param.pooling_convention == pool_enum::kValid) {
return true;
} else {
if (param.pool_type == pool_enum::kAvgPooling) {
CHECK_EQ(dshape.ndim(), 4);
// mkldnn works differently when padding is asymmetric, so let's skip this case.
if (param.pad[0] == GetPaddingSizeFull(dshape[2], param.pad[0], param.pad[0], param.kernel[0],
param.stride[0]) &&
param.pad[1] == GetPaddingSizeFull(dshape[3], param.pad[1], param.pad[1], param.kernel[1],
param.stride[1])) {
return true;
bool is_symmetric = true;
switch (ndim) {
case 5:
is_symmetric = is_symmetric && (param.pad[2] == GetPaddingSizeFull(dshape[4],
param.pad[2], param.pad[2], param.kernel[2], param.stride[2]));
case 4:
is_symmetric = is_symmetric && (param.pad[1] == GetPaddingSizeFull(dshape[3],
param.pad[1], param.pad[1], param.kernel[1], param.stride[1]));
case 3:
is_symmetric = is_symmetric && (param.pad[0] == GetPaddingSizeFull(dshape[2],
param.pad[0], param.pad[0], param.kernel[0], param.stride[0]));
}
return false;
return is_symmetric;
}
return param.pool_type == pool_enum::kMaxPooling;
}
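
For average pooling under the "full" pooling convention, MKL-DNN only matches MXNet's arithmetic when the padding works out symmetric, so the updated check walks the spatial axes (falling through from the 5d case to 4d to 3d) and compares each pad against GetPaddingSizeFull. A compact sketch of the same per-axis test written as a loop, with the padding computation passed in as a callable so the snippet stays self-contained (the helper names are illustrative):

#include <cstdint>
#include <functional>
#include <vector>

// True when, for every spatial axis, the user-specified pad equals the pad
// the "full" convention would actually require on both sides; this is the
// condition under which the diff above keeps MKL-DNN average pooling enabled.
bool HasSymmetricFullPadding(
    const std::vector<int64_t>& spatial,   // input (D,) H, W sizes
    const std::vector<int>& pad,
    const std::vector<int>& kernel,
    const std::vector<int>& stride,
    const std::function<int(int64_t, int, int, int, int)>& padding_size_full) {
  for (size_t i = 0; i < spatial.size(); ++i) {
    if (pad[i] != padding_size_full(spatial[i], pad[i], pad[i],
                                    kernel[i], stride[i]))
      return false;
  }
  return true;
}
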