
[MXNET-33] Enhance mkldnn pooling to support full convention #11047

Merged · 18 commits · Nov 17, 2018
Changes from 8 commits
17 changes: 1 addition & 16 deletions src/operator/nn/mkldnn/mkldnn_pooling-inl.h
@@ -87,23 +87,8 @@ inline bool SupportMKLDNNPooling(const PoolingParam &param) {
 inline bool SupportMKLDNNPooling(const PoolingParam &param,
                                  const TShape &dshape) {
   bool ret = SupportMKLDNNPooling(param);
-  if (!ret)
-    return false;
-
-  if (param.pooling_convention == pool_enum::kValid)
-    return true;
-  else
-    return false;
-
-  // need to support pooling convention full
-  // https://issues.apache.org/jira/browse/MXNET-33
-#if 0
-  if (((dshape[2] + 2 * param.pad[0] - param.kernel[0]) % param.stride[0] == 0) &&
-      ((dshape[3] + 2 * param.pad[1] - param.kernel[1]) % param.stride[1] == 0))
-    return true;
-  else
-    return false;
-#endif
Contributor:
Can we ignore the shape completely? Even for the case of kValid?

Member Author:
Yes, I think so. Previously, the mkldnn pooling operator only supported pooling_convention=kValid, so there was no need to check the shape. But to support kFull, we need to adjust the padding size to get the correct output shape.
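
To make the two conventions concrete, here is a minimal sketch (illustrative, not part of the PR) of MXNet's output-shape formulas: kValid floors the window count while kFull ceils it, which is why kFull may need extra trailing padding.

```python
import math

def pool_out_dim(x, kernel, pad, stride, convention='valid'):
    """Output size along one spatial axis for MXNet's pooling conventions."""
    span = x + 2 * pad - kernel
    if convention == 'valid':  # floor: a trailing partial window is dropped
        return span // stride + 1
    if convention == 'full':   # ceil: a trailing partial window is kept
        return int(math.ceil(span / stride)) + 1
    raise ValueError('unknown convention: %s' % convention)

# Height axis of the new test below: x=10, kernel=4, pad=1, stride=3
print(pool_out_dim(10, 4, 1, 3, 'valid'))  # 3
print(pool_out_dim(10, 4, 1, 3, 'full'))   # 4
```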

+  return ret;
 }

 inline bool MKLDNNRequireWorkspace(const PoolingParam &param) {
31 changes: 31 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -150,6 +150,16 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwd(const PoolingParam &param,
   int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
   int stride_h_ = param.stride[0], stride_w_ = param.stride[1];

+  if (param.pooling_convention == pool_enum::kFull) {
Contributor:
Is it possible to write a macro/function for the same kFull check, which is repeated in different places in this file?

Member Author:
Done.

+    if ((data_md.data.dims[2] + pad_t_ + pad_b_ - kernel_h_) % stride_h_ != 0) {
+      pad_b_ += stride_h_ - ((data_md.data.dims[2] + pad_t_ + pad_b_ - kernel_h_) % stride_h_);
+    }
+
+    if ((data_md.data.dims[3] + pad_l_ + pad_r_ - kernel_w_) % stride_w_ != 0) {
+      pad_r_ += stride_w_ - ((data_md.data.dims[3] + pad_l_ + pad_r_ - kernel_w_) % stride_w_);
+    }
+  }

   const mkldnn::engine engine = CpuEngine::Get()->get_engine();
   if (param.global_pool) {
     pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
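
As a quick sanity check on the adjustment above, a short sketch (illustrative only, using the height-axis numbers from the new test) showing that growing the trailing pad until the window span divides evenly by the stride reproduces the full-convention (ceil) output size:

```python
import math

def adjust_trailing_pad(x, kernel, pad_lo, pad_hi, stride):
    # Mirrors the kFull branch: bump the trailing pad so that
    # (x + pad_lo + pad_hi - kernel) becomes a multiple of stride.
    rem = (x + pad_lo + pad_hi - kernel) % stride
    return pad_hi + (stride - rem) if rem != 0 else pad_hi

pad_b = adjust_trailing_pad(10, 4, 1, 1, 3)          # 1 -> 2
out_adjusted = (10 + 1 + pad_b - 4) // 3 + 1         # floor math, bigger pad
out_full = int(math.ceil((10 + 2 * 1 - 4) / 3)) + 1  # full-convention formula
assert out_adjusted == out_full == 4
```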
@@ -223,6 +233,16 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
   int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
   int stride_h_ = param.stride[0], stride_w_ = param.stride[1];

+  if (param.pooling_convention == pool_enum::kFull) {
+    if ((data_md.data.dims[2] + pad_t_ + pad_b_ - kernel_h_) % stride_h_ != 0) {
+      pad_b_ += stride_h_ - ((data_md.data.dims[2] + pad_t_ + pad_b_ - kernel_h_) % stride_h_);
+    }
+
+    if ((data_md.data.dims[3] + pad_l_ + pad_r_ - kernel_w_) % stride_w_ != 0) {
+      pad_r_ += stride_w_ - ((data_md.data.dims[3] + pad_l_ + pad_r_ - kernel_w_) % stride_w_);
+    }
+  }
+
   if (param.global_pool) {
     pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
     stride_h_ = stride_w_ = 1;
@@ -299,6 +319,17 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param,
   int pad_t_ = param.pad[0], pad_b_ = param.pad[0];
   int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
   int stride_h_ = param.stride[0], stride_w_ = param.stride[1];

+  if (param.pooling_convention == pool_enum::kFull) {
+    if ((data_md.data.dims[2] + pad_t_ + pad_b_ - kernel_h_) % stride_h_ != 0) {
+      pad_b_ += stride_h_ - ((data_md.data.dims[2] + pad_t_ + pad_b_ - kernel_h_) % stride_h_);
+    }
+
+    if ((data_md.data.dims[3] + pad_l_ + pad_r_ - kernel_w_) % stride_w_ != 0) {
+      pad_r_ += stride_w_ - ((data_md.data.dims[3] + pad_l_ + pad_r_ - kernel_w_) % stride_w_);
+    }
+  }
+
   if (param.global_pool) {
     pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
     stride_h_ = stride_w_ = 1;
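Note: the same kFull adjustment now appears three times (both GetPoolingFwd overloads and MKLDNNPoolingGradCompute), which keeps the backward primitive's effective padding consistent with the forward pass. The macro/function refactor requested above (and answered with "Done.") is in a later commit, not shown at this 8-commit diff state.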
29 changes: 29 additions & 0 deletions tests/python/gpu/test_operator_gpu.py
@@ -920,6 +920,35 @@ def test_3d_pooling(pool_type, p_value=2):
     test_3d_pooling('lp', p_value=3)


+@with_seed()
+def test_pooling_full_2d():
+    def test_pooling_full_2d_type(pool_type):
+        data = (2, 2, 10, 10)
+        kernel = (4, 5)
+        pad = (1, 2)
+        stride = (3, 4)
+
+        convention = 'full'
+        ctx_list = []
+        sym_list = []
+
+        # o_h = ceil((10 + 1 + 1 - 4) / 3) + 1 = 4
+        # o_w = ceil((10 + 2 + 2 - 5) / 4) + 1 = 4
+        ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
+        sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
+                                       pooling_convention=convention, global_pool=True, name='pool'))
+
+        ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
+        sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
+                                       pooling_convention=convention, global_pool=True, name='pool'))
Contributor:
It doesn't seem you test your code. Once global_pool is true, all paddings are set to 0.

Contributor:
The symbols defined on lines 938 and 942 are exactly the same.

Contributor:
Please check your test code with and without your fix to make sure that your test can trigger the bug.

+        check_consistency(sym_list, ctx_list)
+
+    test_pooling_full_2d_type('max')
+    test_pooling_full_2d_type('avg')
+    test_pooling_full_2d_type('sum')
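
To make the reviewers' point concrete, here is a hypothetical revision of the test (an illustration, not the fix that eventually landed): with global_pool=True removed, the pad/stride/pooling_convention arguments take effect, so the CPU (MKLDNN) vs GPU consistency check can actually exercise the kFull padding adjustment.

```python
import numpy as np
import mxnet as mx
from mxnet.test_utils import check_consistency

def check_pooling_full_2d(pool_type):
    data = (2, 2, 10, 10)
    kernel, pad, stride = (4, 5), (1, 2), (3, 4)
    ctx_list, sym_list = [], []
    for ctx in (mx.cpu(0), mx.gpu(0)):
        ctx_list.append({'ctx': ctx, 'pool_data': data,
                         'type_dict': {'pool_data': np.float32}})
        # No global_pool=True here, so the 'full' convention is exercised
        # and CPU/GPU would disagree without the padding fix.
        sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride,
                                       pool_type=pool_type,
                                       pooling_convention='full', name='pool'))
    check_consistency(sym_list, ctx_list)
```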


 @with_seed()
 def test_global_pooling():
     def test_1d_pooling(pool_type, p_value=2):