diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
index 66679613d3ae..f548778c7615 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
@@ -113,20 +113,12 @@ inline bool SupportMKLDNNPooling(const PoolingParam &param,
   if (!ret)
     return false;
 
-  if (param.pooling_convention == pool_enum::kValid)
+  if (param.pooling_convention == pool_enum::kValid) {
     return true;
-  else
-    return false;
-
-// need to support pooling convention full
-// https://issues.apache.org/jira/browse/MXNET-33
-#if 0
-  if (((dshape[2] + 2 * param.pad[0] - param.kernel[0]) % param.stride[0] == 0) &&
-      ((dshape[3] + 2 * param.pad[1] - param.kernel[1]) % param.stride[1] == 0))
-    return true;
-  else
-    return false;
-#endif
+  } else {
+    // currently, only max-pooling is supported for full convention
+    return param.pool_type == pool_enum::kMaxPooling;
+  }
 }
 
 inline bool MKLDNNRequireWorkspace(const PoolingParam &param) {
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc
index 1610944304e1..18dc835c0d0b 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -134,6 +134,14 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam &param) {
   }
 }
 
+static inline int GetPaddingSizeFull(int x, int padl, int padr, int k, int s) {
+  if ((x + padl + padr - k) % s != 0) {
+    return (padr + s - ((x + padl + padr - k) % s));
+  } else {
+    return padr;
+  }
+}
+
 mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
     const PoolingParam &param, const bool is_train, const memory::desc &data_md,
     const memory::desc &out_md) {
@@ -154,11 +162,17 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
   int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
   int stride_h_ = param.stride[0], stride_w_ = param.stride[1];
 
+  if (param.pooling_convention == pool_enum::kFull) {
+    pad_b_ = GetPaddingSizeFull(data_md.data.dims[2], pad_t_, pad_b_, kernel_h_, stride_h_);
+    pad_r_ = GetPaddingSizeFull(data_md.data.dims[3], pad_l_, pad_r_, kernel_w_, stride_w_);
+  }
+
   const mkldnn::engine engine = CpuEngine::Get()->get_engine();
   if (param.global_pool) {
     pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
     stride_h_ = stride_w_ = 1;
   }
+
   if (pad_t_ != 0 || pad_l_ != 0) {
     CHECK(param.pool_type == pool_enum::kAvgPooling ||
           param.pool_type == pool_enum::kMaxPooling)
@@ -167,7 +181,6 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
     CHECK_LT(pad_t_, kernel_h_);
   }
 
-
   const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
   mkldnn::prop_kind kind = mkldnn::prop_kind::forward_scoring;
   if (is_train && alg != algorithm::pooling_avg) {
@@ -227,17 +240,22 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
   int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
   int stride_h_ = param.stride[0], stride_w_ = param.stride[1];
 
+  if (param.pooling_convention == pool_enum::kFull) {
+    pad_b_ = GetPaddingSizeFull(data_md.data.dims[2], pad_t_, pad_b_, kernel_h_, stride_h_);
+    pad_r_ = GetPaddingSizeFull(data_md.data.dims[3], pad_l_, pad_r_, kernel_w_, stride_w_);
+  }
+
   if (param.global_pool) {
-      pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
-      stride_h_ = stride_w_ = 1;
+    pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
+    stride_h_ = stride_w_ = 1;
   }
 
   if (pad_t_ != 0 || pad_l_ != 0) {
-      CHECK(param.pool_type == pool_enum::kAvgPooling ||
-            param.pool_type == pool_enum::kMaxPooling)
-          << "Padding implemented only for average and max pooling.";
-      CHECK_LT(pad_l_, kernel_w_);
-      CHECK_LT(pad_t_, kernel_h_);
+    CHECK(param.pool_type == pool_enum::kAvgPooling ||
+          param.pool_type == pool_enum::kMaxPooling)
+        << "Padding implemented only for average and max pooling.";
+    CHECK_LT(pad_l_, kernel_w_);
+    CHECK_LT(pad_t_, kernel_h_);
   }
 
   const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
@@ -353,6 +371,12 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
   int pad_t_ = param.pad[0], pad_b_ = param.pad[0];
   int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
   int stride_h_ = param.stride[0], stride_w_ = param.stride[1];
+
+  if (param.pooling_convention == pool_enum::kFull) {
+    pad_b_ = GetPaddingSizeFull(data_md.data.dims[2], pad_t_, pad_b_, kernel_h_, stride_h_);
+    pad_r_ = GetPaddingSizeFull(data_md.data.dims[3], pad_l_, pad_r_, kernel_w_, stride_w_);
+  }
+
   if (param.global_pool) {
     pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
     stride_h_ = stride_w_ = 1;
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index e3299681814d..8054937a84c6 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -996,6 +996,35 @@ def test_3d_pooling(pool_type, p_value=2, count_include_pad=True):
     test_3d_pooling('lp', p_value=3)
 
 
+@with_seed()
+def test_pooling_full_2d():
+    def test_pooling_full_2d_type(pool_type):
+        data = (2, 2, 10, 10)
+        kernel = (4, 5)
+        pad = (1, 2)
+        stride = (3, 4)
+
+        convention = 'full'
+        ctx_list = []
+        sym_list = []
+
+        # o_h = ceil((10 + 1 + 1 - 4) / 3) + 1 = 4
+        # o_w = ceil((10 + 2 + 2 - 5) / 4) + 1 = 4
+        ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
+        sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
+                                       pooling_convention=convention, global_pool=False, name='pool'))
+
+        ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
+        sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
+                                       pooling_convention=convention, global_pool=False, name='pool'))
+
+        check_consistency(sym_list, ctx_list)
+
+    test_pooling_full_2d_type('max')
+    test_pooling_full_2d_type('avg')
+    test_pooling_full_2d_type('sum')
+
+
 @with_seed()
 def test_global_pooling():
     def test_1d_pooling(pool_type, p_value=2):
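
For reviewers: the idea behind `GetPaddingSizeFull` is to widen the bottom/right padding just enough that MKL-DNN's floor-based ("valid") output-size computation lands on the ceil-based "full"-convention size. Below is a minimal standalone Python sketch, not part of the patch (`get_padding_size_full` is a hypothetical port of the C++ helper), checked against the shapes in the new test:

```python
import math

def get_padding_size_full(x, padl, padr, k, s):
    # Grow the trailing pad until (x + padl + padr - k) is divisible
    # by the stride, mirroring the C++ GetPaddingSizeFull helper.
    rem = (x + padl + padr - k) % s
    return padr + s - rem if rem != 0 else padr

# Shapes from the test: 10x10 input, kernel (4, 5), pad (1, 2), stride (3, 4).
h, kernel_h, pad_t, stride_h = 10, 4, 1, 3
w, kernel_w, pad_l, stride_w = 10, 5, 2, 4

# pad_b/pad_r start out equal to pad_t/pad_l, then get widened.
pad_b = get_padding_size_full(h, pad_t, pad_t, kernel_h, stride_h)  # 1 -> 2
pad_r = get_padding_size_full(w, pad_l, pad_l, kernel_w, stride_w)  # 2 -> 5

# 'full' convention output size: ceil((x + 2 * pad - k) / s) + 1
o_h = math.ceil((h + 2 * pad_t - kernel_h) / stride_h) + 1  # 4
o_w = math.ceil((w + 2 * pad_l - kernel_w) / stride_w) + 1  # 4

# With the widened padding, the floor-based formula MKL-DNN applies
# yields the same output size as the 'full' convention.
assert (h + pad_t + pad_b - kernel_h) // stride_h + 1 == o_h
assert (w + pad_l + pad_r - kernel_w) // stride_w + 1 == o_w
```

Running the sketch reproduces the o_h = o_w = 4 values asserted in the test's comments.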