From acb507b9954744e4d76b7cd09cd1512386e4d053 Mon Sep 17 00:00:00 2001 From: Pedro Freire Date: Fri, 15 Nov 2019 20:59:57 +0000 Subject: [PATCH 1/9] Add Deformable Convolution operation. This adds the deformable convolution operation, as described in Deformable Convolutional Networks (https://arxiv.org/abs/1703.06211). - The code is based on https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp ; the whole code was modified and refactored to remove redundancies and increase clarity, and to adapt it to torchvision. - The CPU part is a direct copy of the CUDA code; it might make sense to do follow-up adjustments in the CPU code to simplify it / optimize it, or to reuse functionality between CPU and CUDA.. - We also add tests (with a non-trivial set of parameters); they can be made more robust by randomizing the parameters and executing multiple times. --- test/test_ops.py | 179 ++++-- torchvision/csrc/DeformConv.h | 141 +++++ torchvision/csrc/cpu/DeformConv_cpu.cpp | 643 ++++++++++++++++++++ torchvision/csrc/cpu/vision_cpu.h | 23 + torchvision/csrc/cuda/DeformConv_cuda.cu | 728 +++++++++++++++++++++++ torchvision/csrc/cuda/vision_cuda.h | 23 + torchvision/csrc/vision.cpp | 2 + torchvision/ops/__init__.py | 10 +- torchvision/ops/deform_conv.py | 70 +++ 9 files changed, 1780 insertions(+), 39 deletions(-) create mode 100644 torchvision/csrc/DeformConv.h create mode 100644 torchvision/csrc/cpu/DeformConv_cpu.cpp create mode 100644 torchvision/csrc/cuda/DeformConv_cuda.cu create mode 100644 torchvision/ops/deform_conv.py diff --git a/test/test_ops.py b/test/test_ops.py index 9d4916771ab..352fbac2260 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,15 +1,18 @@ from __future__ import division +import math +import unittest + import numpy as np + import torch +from torch import Tensor from torch.autograd import gradcheck - +from torch.jit.annotations import Tuple +from torch.nn.modules.utils import _pair from torchvision import ops -from itertools import product -import unittest - -class RoIOpTester(object): +class OpTester(object): @classmethod def setUpClass(cls): cls.dtype = torch.float64 @@ -42,6 +45,14 @@ def test_backward_cuda_contiguous(self): def test_backward_cuda_non_contiguous(self): self._test_backward(device=torch.device('cuda'), contiguous=False) + def _test_forward(self, device, contiguous): + pass + + def _test_backward(self, device, contiguous): + pass + + +class RoIOpTester(OpTester): def _test_forward(self, device, contiguous): pool_size = 5 # n_channels % (pool_size ** 2) == 0 required for PS opeartions. @@ -79,7 +90,6 @@ def func(z): self.assertTrue(gradcheck(func, (x,))) self.assertTrue(gradcheck(script_func, (x,))) - return def fn(*args, **kwargs): pass @@ -98,7 +108,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar def get_script_fn(self, rois, pool_size): @torch.jit.script def script_fn(input, rois, pool_size): - # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor + # type: (Tensor, Tensor, int) -> Tensor return ops.roi_pool(input, rois, pool_size, 1.0)[0] return lambda x: script_fn(x, rois, pool_size) @@ -137,7 +147,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar def get_script_fn(self, rois, pool_size): @torch.jit.script def script_fn(input, rois, pool_size): - # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor + # type: (Tensor, Tensor, int) -> Tensor return ops.ps_roi_pool(input, rois, pool_size, 1.0)[0] return lambda x: script_fn(x, rois, pool_size) @@ -174,29 +184,35 @@ def get_slice(k, block): return y -def bilinear_interpolate(data, height, width, y, x): - if y < -1.0 or y > height or x < -1.0 or x > width: - return 0. +def bilinear_interpolate(data, y, x, snap_border=False): + height, width = data.shape - y = min(max(0, y), height - 1) - x = min(max(0, x), width - 1) + if snap_border: + if -1 < y <= 0: + y = 0 + elif height - 1 <= y < height: + y = height - 1 - y_low = int(y) - y_high = min(y_low + 1, height - 1) + if -1 < x <= 0: + x = 0 + elif width - 1 <= x < width: + x = width - 1 - x_low = int(x) - x_high = min(x_low + 1, width - 1) + y_low = int(math.floor(y)) + x_low = int(math.floor(x)) + y_high = y_low + 1 + x_high = x_low + 1 wy_h = y - y_low - wy_l = 1 - wy_h - wx_h = x - x_low + wy_l = 1 - wy_h wx_l = 1 - wx_h val = 0 - for wx, x in zip((wx_l, wx_h), (x_low, x_high)): - for wy, y in zip((wy_l, wy_h), (y_low, y_high)): - val += wx * wy * data[y * width + x] + for wx, xp in zip((wx_l, wx_h), (x_low, x_high)): + for wy, yp in zip((wy_l, wy_h), (y_low, y_high)): + if 0 <= yp < height and 0 <= xp < width: + val += wx * wy * data[yp, xp] return val @@ -208,7 +224,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar def get_script_fn(self, rois, pool_size): @torch.jit.script def script_fn(input, rois, pool_size): - # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor + # type: (Tensor, Tensor, int) -> Tensor return ops.roi_align(input, rois, pool_size, 1.0)[0] return lambda x: script_fn(x, rois, pool_size) @@ -242,12 +258,7 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r y = start_h + (iy + 0.5) * bin_h / grid_h for ix in range(0, grid_w): x = start_w + (ix + 0.5) * bin_w / grid_w - val += bilinear_interpolate( - in_data[batch_idx, channel, :, :].flatten(), - in_data.size(-2), - in_data.size(-1), - y, x - ) + val += bilinear_interpolate(in_data[batch_idx, channel, :, :], y, x, snap_border=True) val /= grid_h * grid_w out_data[r, channel, i, j] = val @@ -262,7 +273,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar def get_script_fn(self, rois, pool_size): @torch.jit.script def script_fn(input, rois, pool_size): - # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor + # type: (Tensor, Tensor, int) -> Tensor return ops.ps_roi_align(input, rois, pool_size, 1.0)[0] return lambda x: script_fn(x, rois, pool_size) @@ -298,12 +309,7 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, device, spatial_scale=1, y = start_h + (iy + 0.5) * bin_h / grid_h for ix in range(0, grid_w): x = start_w + (ix + 0.5) * bin_w / grid_w - val += bilinear_interpolate( - in_data[batch_idx, c_in, :, :].flatten(), - in_data.size(-2), - in_data.size(-1), - y, x - ) + val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True) val /= grid_h * grid_w out_data[r, c_out, i, j] = val @@ -376,5 +382,106 @@ def test_new_empty_tensor(self): assert out.dtype == input.dtype +class DeformConvTester(OpTester, unittest.TestCase): + def expected_fn(self, x, offsets, weights, *args, stride=1, pad=0, dilation=1): + stride_h, stride_w = _pair(stride) + pad_h, pad_w = _pair(pad) + dil_h, dil_w = _pair(dilation) + weights_h, weights_w = weights.shape[-2:] + + n_batches, n_in_channels, in_h, in_w = x.shape + n_out_channels = weights.shape[0] + + out_h = (in_h + 2 * pad_h - (dil_h * (weights_h - 1) + 1)) // stride_h + 1 + out_w = (in_w + 2 * pad_w - (dil_w * (weights_w - 1) + 1)) // stride_w + 1 + + n_offset_grps = offsets.shape[1] // (2 * weights_h * weights_w) + in_c_per_offset_grp = n_in_channels // n_offset_grps + + n_weight_grps = n_in_channels // weights.shape[1] + in_c_per_weight_grp = weights.shape[1] + out_c_per_weight_grp = n_out_channels // n_weight_grps + + out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype) + for b in range(n_batches): + for c_out in range(n_out_channels): + for i in range(out_h): + for j in range(out_w): + for di in range(weights_h): + for dj in range(weights_w): + for c in range(in_c_per_weight_grp): + weight_grp = c_out // out_c_per_weight_grp + c_in = weight_grp * in_c_per_weight_grp + c + + offset_grp = c_in // in_c_per_offset_grp + offset_idx = 2 * (offset_grp * (weights_h * weights_w) + di * weights_w + dj) + + pi = stride_h * i - pad_h + dil_h * di + offsets[b, offset_idx, i, j] + pj = stride_w * j - pad_w + dil_w * dj + offsets[b, offset_idx + 1, i, j] + + out[b, c_out, i, j] += (weights[c_out, c, di, dj] * + bilinear_interpolate(x[b, c_in, :, :], pi, pj)) + return out + + def get_fn_args(self, device, contiguous): + batch_sz = 1 + n_in_channels = 6 + n_out_channels = 2 + n_weight_grps = 2 + n_offset_grps = 3 + + stride = (2, 1) + pad = (1, 0) + dilation = (2, 1) + + stride_h, stride_w = stride + pad_h, pad_w = pad + dil_h, dil_w = dilation + weight_h, weight_w = (3, 2) + in_h, in_w = (5, 4) + + out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1 + out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1 + + x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=self.dtype, requires_grad=True) + + offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w, + device=device, dtype=self.dtype, requires_grad=True) + + weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w, + device=device, dtype=self.dtype, requires_grad=True) + + if not contiguous: + x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2) + offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1) + weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0) + + return x, offset, weight, stride, pad, dilation + + def _test_forward(self, device, contiguous): + x, offset, weight, stride, pad, dilation = self.get_fn_args(device, contiguous) + + res = ops.DeformConv(stride=stride, pad=pad, dilation=dilation)(x, offset, weight) + expected = self.expected_fn(x, offset, weight, stride=stride, pad=pad, dilation=dilation) + + self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(x, res, expected)) + + def _test_backward(self, device, contiguous): + x, offset, weight, stride, pad, dilation = self.get_fn_args(device, contiguous) + + def func(x_, offset_, weight_): + return ops.deform_conv(x_, offset_, weight_, stride=stride, pad=pad, dilation=dilation) + + gradcheck(func, (x, offset, weight), nondet_tol=1e-5) + + @torch.jit.script + def script_func(x_, offset_, weight_, stride_, pad_, dilation_): + # type: (Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor + return ops.deform_conv(x_, offset_, weight_, stride=stride_, pad=pad_, dilation=dilation_) + + gradcheck(lambda z, off, wei: script_func(z, off, wei, stride, pad, dilation), + (x, offset, weight), nondet_tol=1e-5) + + if __name__ == '__main__': unittest.main() diff --git a/torchvision/csrc/DeformConv.h b/torchvision/csrc/DeformConv.h new file mode 100644 index 00000000000..1f04259e9c8 --- /dev/null +++ b/torchvision/csrc/DeformConv.h @@ -0,0 +1,141 @@ +#pragma once + +#include "cpu/vision_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/vision_cuda.h" +#endif + +at::Tensor DCN_forward( + const Tensor& input, + const Tensor& offset, + const Tensor& weights, + const std::pair& stride, + const std::pair& pad, + const std::pair& dilation, + const int groups, + const int deformable_groups, + const int n_parallel_imgs) { + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return DCN_forward_cuda(input.contiguous(), offset.contiguous(), weights.contiguous(), stride, pad, + dilation, groups, deformable_groups, n_parallel_imgs); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + return DCN_forward_cpu(input.contiguous(), offset.contiguous(), weights.contiguous(), stride, pad, + dilation, groups, deformable_groups, n_parallel_imgs); +} + +std::tuple DCN_backward( + const at::Tensor& grad, + const Tensor& input, + const Tensor& offset, + const Tensor& weights, + const std::pair& stride, + const std::pair& pad, + const std::pair& dilation, + const int groups, + const int deformable_groups, + const int n_parallel_imgs) { + if (grad.type().is_cuda()) { +#ifdef WITH_CUDA + return DCN_backward_cuda(grad.contiguous(), input.contiguous(), offset.contiguous(), weights.contiguous(), stride, pad, + dilation, groups, deformable_groups, n_parallel_imgs); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + return DCN_backward_cpu(grad.contiguous(), input.contiguous(), offset.contiguous(), weights.contiguous(), stride, pad, + dilation, groups, deformable_groups, n_parallel_imgs); +} + +using namespace at; +using torch::Tensor; +using torch::autograd::AutogradContext; +using torch::autograd::Variable; +using torch::autograd::variable_list; + +class DeformConvFunction : public torch::autograd::Function { + public: + static variable_list forward( + AutogradContext* ctx, + Variable input, + Variable offset, + Variable weights, + int64_t stride_h, int64_t stride_w, + int64_t pad_h, int64_t pad_w, + int64_t dilation_h, int64_t dilation_w, + int64_t groups, + int64_t deformable_groups, + int64_t n_parallel_imgs) { + auto output = DCN_forward(input, offset, weights, + {stride_h, stride_w}, + {pad_h, pad_w}, + {dilation_h, dilation_w}, + groups, deformable_groups, n_parallel_imgs); + + ctx->save_for_backward({input, offset, weights}); + ctx->saved_data["stride_h"] = stride_h; + ctx->saved_data["stride_w"] = stride_w; + ctx->saved_data["pad_h"] = pad_h; + ctx->saved_data["pad_w"] = pad_w; + ctx->saved_data["dilation_h"] = dilation_h; + ctx->saved_data["dilation_w"] = dilation_w; + ctx->saved_data["groups"] = groups; + ctx->saved_data["deformable_groups"] = deformable_groups; + ctx->saved_data["n_parallel_imgs"] = n_parallel_imgs; + + return {output,}; + } + + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { + auto saved = ctx->get_saved_variables(); + auto input = saved[0]; + auto offset = saved[1]; + auto weight = saved[2]; + + auto stride_h = ctx->saved_data["stride_h"].toInt(); + auto stride_w = ctx->saved_data["stride_w"].toInt(); + auto pad_h = ctx->saved_data["pad_h"].toInt(); + auto pad_w = ctx->saved_data["pad_w"].toInt(); + auto dilation_h = ctx->saved_data["dilation_h"].toInt(); + auto dilation_w = ctx->saved_data["dilation_w"].toInt(); + auto groups = ctx->saved_data["groups"].toInt(); + auto deformable_groups = ctx->saved_data["deformable_groups"].toInt(); + auto n_parallel_imgs = ctx->saved_data["n_parallel_imgs"].toInt(); + + auto grads = DCN_backward(grad_output[0], + input, offset, weight, + {stride_h, stride_w}, + {pad_h, pad_w}, + {dilation_h, dilation_w}, + groups, deformable_groups, n_parallel_imgs); + auto grad_input = std::get<0>(grads); + auto grad_offset = std::get<1>(grads); + auto grad_weight = std::get<2>(grads); + + return {grad_input, grad_offset, grad_weight, + Variable(), Variable(), Variable(), + Variable(), Variable(), Variable(), + Variable(), Variable(), Variable(),}; + } +}; + +Tensor deform_conv( + const Tensor& input, + const Tensor& offset, + const Tensor& weights, + int64_t stride_h, int64_t stride_w, + int64_t pad_h, int64_t pad_w, + int64_t dilation_h, int64_t dilation_w, + int64_t groups, + int64_t deformable_groups, + int64_t n_parallel_imgs) { + auto result = DeformConvFunction::apply(input, offset, weights, stride_h, stride_w, pad_h, pad_w, + dilation_h, dilation_w, groups, deformable_groups, n_parallel_imgs); + return result[0]; +} diff --git a/torchvision/csrc/cpu/DeformConv_cpu.cpp b/torchvision/csrc/cpu/DeformConv_cpu.cpp new file mode 100644 index 00000000000..eb47a652c51 --- /dev/null +++ b/torchvision/csrc/cpu/DeformConv_cpu.cpp @@ -0,0 +1,643 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +// modified from https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp + + +#include +#include +#include + +#include + + +using namespace at; + +template +static scalar_t bilinear_interpolate(const scalar_t *in, const int height, const int width, scalar_t h, scalar_t w) { + if (h <= -1 || height <= h || w <= -1 || width <= w) { + return 0; + } + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = in[h_low * width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = in[h_low * width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = in[h_high * width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = in[h_high * width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +static void deformable_im2col_kernel(const int n, const scalar_t* input, const scalar_t* offset, + const int height, const int width, const int weight_h, const int weight_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dil_h, const int dil_w, + const int batch_sz, const int n_in_channels, const int n_offset_grps, + const int out_h, const int out_w, + scalar_t* columns) { + for(int index = 0; index != n; ++index) { + const int out_x = index % out_w; + const int out_y = (index / out_w) % out_h; + const int out_b = (index / (out_w * out_h)) % batch_sz; + const int in_c = index / (out_w * out_h * batch_sz); + const int out_c = in_c * weight_h * weight_w; + + int c_per_offset_grp = n_in_channels / n_offset_grps; + const int grp_idx = in_c / c_per_offset_grp; + + auto columns_ptr = columns + (out_c * (batch_sz * out_h * out_w) + + out_b * (out_h * out_w) + + out_y * out_w + + out_x); + + auto input_ptr = input + (out_b * (n_in_channels * height * width) + + in_c * (height * width)); + + auto offset_ptr = offset + (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w; + + for (int i = 0; i < weight_h; ++i) { + for (int j = 0; j < weight_w; ++j) { + const int offset_idx = 2 * (i * weight_w + j); + const scalar_t offset_h = offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t offset_w = offset_ptr[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t y = (out_y * stride_h - pad_h) + i * dil_h + offset_h; + const scalar_t x = (out_x * stride_w - pad_w) + j * dil_w + offset_w; + *columns_ptr = bilinear_interpolate(input_ptr, height, width, y, x); + columns_ptr += batch_sz * out_h * out_w; + } + } + } +} + +static void deformable_im2col( + const at::Tensor input, const at::Tensor data_offset, int n_in_channels, + int height, int width, + int weight_h, int weight_w, + int pad_h, int pad_w, + int stride_h, int stride_w, + int dil_h, int dil_w, + int out_h, int out_w, + int parallel_imgs, int deformable_group, at::Tensor data_col) { + int num_kernels = n_in_channels * out_h * out_w * parallel_imgs; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "deformable_im2col", ([&] { + deformable_im2col_kernel( + num_kernels, + input.data_ptr(), + data_offset.data_ptr(), + height, width, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dil_h, dil_w, + parallel_imgs, n_in_channels, deformable_group, + out_h, out_w, + data_col.data_ptr()); + })); +} + +at::Tensor DCN_forward_cpu( + const at::Tensor& input_param, + const at::Tensor& offset_param, + const at::Tensor& weight_param, + std::pair stride, + std::pair pad, + std::pair dilation, + int n_weight_grps, int n_offset_grps, int n_parallel_imgs) { + at::Tensor input = input_param; + at::Tensor offset = offset_param; + at::Tensor weight = weight_param; + + TORCH_CHECK(input.ndimension() == 4); + TORCH_CHECK(offset.ndimension() == 4); + TORCH_CHECK(weight.ndimension() == 4); + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(offset.is_contiguous()); + TORCH_CHECK(weight.is_contiguous()); + TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); + + int batch_sz = input.size(0); + int n_in_channels = input.size(1); + int in_h = input.size(2); + int in_w = input.size(3); + + n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + + // Unpack shapes and args + int out_channels = weight.size(0); + int weight_h = weight.size(2); + int weight_w = weight.size(3); + + int stride_h = stride.first; + int stride_w = stride.second; + + int pad_h = pad.first; + int pad_w = pad.second; + + int dil_h = dilation.first; + int dil_w = dilation.second; + + int ker_h = dil_h * (weight_h - 1) + 1; + int ker_w = dil_w * (weight_w - 1) + 1; + int out_h = ((in_h + 2*pad_h - ker_h) / stride_h) + 1; + int out_w = ((in_w + 2*pad_w - ker_w) / stride_w) + 1; + + + TORCH_CHECK(batch_sz % n_parallel_imgs == 0); + + TORCH_CHECK(weight_h > 0 && weight_w > 0, "weight_h: ", weight_w, " weight_w: ", weight_h); + TORCH_CHECK(stride_h > 0 && stride_w > 0, "stride_h: ", stride_w, " stride_w: ", stride_h); + TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_w, " pad_w: ", pad_h); + TORCH_CHECK(dil_h > 0 && dil_w > 0, "dil_h: ", dil_w, " dil_w: ", dil_h); + + TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1)); + TORCH_CHECK(weight.size(0) % n_weight_grps == 0); + TORCH_CHECK(input.size(1) % n_offset_grps == 0); + + TORCH_CHECK((offset.size(0) == input.size(0)), "invalid batch size of offset"); + TORCH_CHECK((offset.size(1) == n_offset_grps * 2 * weight_h * weight_w), + "got: ", offset.size(1), " expected: ", n_offset_grps * 2 * weight_h * weight_w); + TORCH_CHECK((offset.size(2) == out_h && offset.size(3) == out_w), + "offset output dims: (", offset.size(2), ", ", offset.size(3), ") - ", + "computed output dims: (", out_h, ", ", out_w, ")"); + TORCH_CHECK(out_h > 0 && out_w > 0, "Calculated output size too small - out_h: ", out_h, " out_w: ", out_w); + + + auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options()); + // Separate batches into blocks + out = out.view({batch_sz / n_parallel_imgs, n_parallel_imgs, out_channels, out_h, out_w}); + input = input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + offset = offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + at::Tensor out_buf = at::zeros({batch_sz / n_parallel_imgs, out_channels, n_parallel_imgs * out_h, out_w}, out.options()); + + // Separate channels into convolution groups + out_buf = out_buf.view({out_buf.size(0), n_weight_grps, out_buf.size(1) / n_weight_grps, out_buf.size(2), out_buf.size(3)}); + weight = weight.view({n_weight_grps, weight.size(0) / n_weight_grps, weight.size(1), weight.size(2), weight.size(3)}); + + // Sample points and perform convolution + auto columns = at::zeros({n_in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, input.options()); + for (int b = 0; b < batch_sz / n_parallel_imgs; b++) { + deformable_im2col(input[b], offset[b], n_in_channels, in_h, + in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dil_h, + dil_w, out_h, out_w, n_parallel_imgs, n_offset_grps, columns); + + columns = columns.view({n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + for (int g = 0; g < n_weight_grps; g++) { + out_buf[b][g] = out_buf[b][g].flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(out_buf[b][g]); + } + } + + out_buf = out_buf.view({batch_sz / n_parallel_imgs, out_channels, n_parallel_imgs, out_h, out_w}); + out_buf.transpose_(1, 2); + out.copy_(out_buf); + out = out.view({batch_sz, out_channels, out_h, out_w}); + + return out; +} + + +template +static void deformable_col2im_kernel( + const int n, const scalar_t *col, const scalar_t *offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int batch_sz, const int n_offset_grps, + const int out_h, const int out_w, + scalar_t *grad_im) { + for(int index = 0; index != n; ++index) { + const int j = (index / (out_w * out_h * batch_sz)) % kernel_w; + const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h; + const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h); + + int c_per_offset_grp = channels / n_offset_grps; + const int offset_grp = c / c_per_offset_grp; + + int out_x = index % out_w; + int out_y = (index / out_w) % out_h; + int b = (index / (out_w * out_h)) % batch_sz; + + auto offset_ptr = offset + (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h * out_w; + const int offset_h_ptr = ((2 * (i * kernel_w + j)) * out_h + out_y) * out_w + out_x; + const int offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * out_h + out_y) * out_w + out_x; + const scalar_t offset_h = offset_ptr[offset_h_ptr]; + const scalar_t offset_w = offset_ptr[offset_w_ptr]; + const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + + for (int dy = -1; dy <= 1; dy++) { + for (int dx = -1; dx <= 1; dx++) { + int yp = int(y) + dy; + int xp = int(x) + dx; + if (0 <= yp && yp < height && + 0 <= xp && xp < width && + abs(y - yp) < 1 && + abs(x - xp) < 1) { + int grad_pos = ((b * channels + c) * height + yp) * width + xp; + scalar_t weight = (1 - abs(y - yp)) * (1 - abs(x - xp)); + grad_im[grad_pos] += weight * col[index]; + } + } + } + } +} + +static void compute_grad_input( + const at::Tensor columns, const at::Tensor offset, const int channels, + const int height, const int width, const int weight_h, + const int weight_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int n_offset_grps, + at::Tensor grad_im) { + int out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + int out_w = (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * weight_h * weight_w * out_h * out_w * parallel_imgs; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + columns.scalar_type(), "deformable_col2im", ([&] { + deformable_col2im_kernel( + num_kernels, + columns.data_ptr(), + offset.data_ptr(), + channels, height, width, + weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + parallel_imgs, n_offset_grps, out_h, out_w, + grad_im.data_ptr()); + })); +} + + +template +static scalar_t get_coordinate_weight(const scalar_t *im_data, const int height, const int width, scalar_t y, scalar_t x, bool is_y_direction) { + int y_l = floor(y); + int x_l = floor(x); + int y_h = y_l + 1; + int x_h = x_l + 1; + + bool valid_y_l = 0 <= y_l && y_l < height; + bool valid_y_h = 0 <= y_h && y_h < height; + bool valid_x_l = 0 <= x_l && x_l < width; + bool valid_x_h = 0 <= x_h && x_h < width; + + scalar_t zero = 0; + scalar_t v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero; + scalar_t v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero; + scalar_t v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero; + scalar_t v_YX = (valid_y_h && valid_x_h) ? im_data[y_h * width + x_h] : zero; + + if (is_y_direction) { + scalar_t dx = x - x_l; + return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx); + } else { + scalar_t dy = y - y_l; + return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx); + } +} + + +template +static void deformable_col2im_coord_kernel(const int n, const scalar_t *col, + const scalar_t *im, const scalar_t *offset, + const int channels, const int height, const int width, + const int weight_h, const int weight_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int batch_sz, const int offset_channels, const int n_offset_grps, + const int out_h, const int out_w, scalar_t *grad_offset) { + for(int index = 0; index != n; ++index) { + scalar_t val = 0; + int w = index % out_w; + int h = (index / out_w) % out_h; + int c = (index / (out_w * out_h)) % offset_channels; + int b = index / (out_w * out_h * offset_channels); + + const int offset_grp = c / (2 * weight_h * weight_w); + const int col_step = weight_h * weight_w; + + int c_per_offset_grp = channels / n_offset_grps; + + auto col_ptr = col + offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * out_w * out_h; + auto im_ptr = im + (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; + auto offset_ptr = offset + (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h * out_w; + + const int offset_c = c - offset_grp * 2 * weight_h * weight_w; + const int is_y_direction = offset_c % 2 == 0; + + const int c_bound = c_per_offset_grp * weight_h * weight_w; + for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) { + const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w; + + int out_x = col_pos % out_w; + int out_y = (col_pos / out_w) % out_h; + int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w; + int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; + + const int offset_h_idx = (((2 * (i * weight_w + j)) * out_h + out_y) * out_w + out_x); + const int offset_w_idx = (((2 * (i * weight_w + j) + 1) * out_h + out_y) * out_w + out_x); + const scalar_t offset_h = offset_ptr[offset_h_idx]; + const scalar_t offset_w = offset_ptr[offset_w_idx]; + + scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + + const scalar_t weight = get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); + val += weight * col_ptr[col_pos]; + im_ptr += height * width; + } + + grad_offset[index] = val; + } +} + +static void compute_grad_offset( + const at::Tensor columns, const at::Tensor input, const at::Tensor offset, + const int channels, const int height, const int width, const int weight_h, + const int weight_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int n_offset_grps, at::Tensor grad_offset) { + int out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + int out_w = (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + int num_kernels = out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + columns.scalar_type(), "deformable_col2im_coord", ([&] { + deformable_col2im_coord_kernel( + num_kernels, + columns.data_ptr(), + input.data_ptr(), + offset.data_ptr(), + channels, height, width, weight_h, + weight_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, + parallel_imgs, 2 * weight_h * weight_w * n_offset_grps, n_offset_grps, + out_h, out_w, + grad_offset.data_ptr()); + })); +} + + +static std::tuple deform_conv_backward_input_cpu( + at::Tensor input, at::Tensor offset, at::Tensor weight, + at::Tensor grad_out, + std::pair stride, + std::pair pad, + std::pair dilation, + int n_weight_grps, int n_offset_grps, int n_parallel_imgs) { + + int batch_sz = input.size(0); + int n_in_channels = input.size(1); + int in_h = input.size(2); + int in_w = input.size(3); + + n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + + long n_out_channels = weight.size(0); + int weight_h = weight.size(2); + int weight_w = weight.size(3); + + int stride_h = stride.first; + int stride_w = stride.second; + + int pad_h = pad.first; + int pad_w = pad.second; + + int dil_h = dilation.first; + int dil_w = dilation.second; + + long out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) / stride_h + 1; + long out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) / stride_w + 1; + + auto grad_input = at::zeros_like(input); + auto grad_offset = at::zeros_like(offset); + auto columns = at::zeros({n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, input.options()); + + // Separate into blocks + grad_input = grad_input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + input = input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + grad_offset = grad_offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + offset = offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + + grad_out = grad_out.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_out_channels, out_h, out_w}); + grad_out.transpose_(1, 2); + grad_out = grad_out.view( + {grad_out.size(0), n_weight_grps, grad_out.size(1) / n_weight_grps, + grad_out.size(2), grad_out.size(3), grad_out.size(4)}); + + for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { + // Separate into weight groups + columns = columns.view({n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + weight = weight.view({n_weight_grps, weight.size(0) / n_weight_grps, weight.size(1), weight.size(2), weight.size(3)}); + for (int g = 0; g < n_weight_grps; g++) { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); + } + columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + + compute_grad_offset(columns, input[elt], offset[elt], n_in_channels, + in_h, in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, + dil_h, dil_w, n_parallel_imgs, n_offset_grps, + grad_offset[elt]); + + compute_grad_input(columns, offset[elt], n_in_channels, in_h, + in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dil_h, + dil_w, n_parallel_imgs, n_offset_grps, grad_input[elt]); + } + + grad_out = grad_out.view( + {grad_out.size(0), grad_out.size(1) * grad_out.size(2), + grad_out.size(3), grad_out.size(4), grad_out.size(5)}); + grad_out.transpose_(1, 2); + grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w}); + + grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); + input = input.view({batch_sz, n_in_channels, in_h, in_w}); + grad_offset = grad_offset.view({batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + offset = offset.view({batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + + return {grad_input, grad_offset}; +} + + + +static at::Tensor deform_conv_backward_parameters_cpu( + at::Tensor input, at::Tensor offset, at::Tensor weight, + at::Tensor grad_out, + std::pair stride, + std::pair pad, + std::pair dilation, + int n_weight_grps, int n_offset_grps, int n_parallel_imgs) { + + int batch_sz = input.size(0); + int n_in_channels = input.size(1); + int in_h = input.size(2); + int in_w = input.size(3); + + n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + + long n_out_channels = weight.size(0); + int weight_h = weight.size(2); + int weight_w = weight.size(3); + + int stride_h = stride.first; + int stride_w = stride.second; + + int pad_h = pad.first; + int pad_w = pad.second; + + int dil_h = dilation.first; + int dil_w = dilation.second; + + long out_h = grad_out.size(2); + long out_w = grad_out.size(3); + + auto grad_weight = at::zeros_like(weight);; + auto columns = at::zeros({n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, input.options()); + + grad_out = grad_out.view({batch_sz / n_parallel_imgs, n_parallel_imgs, + n_out_channels, out_h, out_w}); + grad_out.transpose_(1, 2); + + at::Tensor grad_out_buf = at::zeros_like(grad_out); + grad_out_buf.copy_(grad_out); + grad_out_buf = grad_out_buf.view({batch_sz / n_parallel_imgs, n_out_channels, n_parallel_imgs * out_h, out_w}); + grad_out_buf = grad_out_buf.view({grad_out_buf.size(0), n_weight_grps, grad_out_buf.size(1) / n_weight_grps, grad_out_buf.size(2), grad_out_buf.size(3)}); + + grad_out.transpose_(1, 2); + grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w}); + + input = input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + offset = offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + + grad_weight = grad_weight.view({n_weight_grps, grad_weight.size(0) / n_weight_grps, grad_weight.size(1), grad_weight.size(2), grad_weight.size(3)}); + for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { + deformable_im2col(input[elt], offset[elt], n_in_channels, in_h, + in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dil_h, + dil_w, out_h, out_w, n_parallel_imgs, n_offset_grps, columns); + + columns = columns.view({n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + for (int g = 0; g < n_weight_grps; g++) { + grad_weight[g] = grad_weight[g] + .flatten(1) + .addmm_(grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) + .view_as(grad_weight[g]); + } + columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + input = input.view({batch_sz, n_in_channels, in_h, in_w}); + offset = offset.view({batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), grad_weight.size(4)}); + return grad_weight; +} + + +std::tuple DCN_backward_cpu( + const at::Tensor& grad_out, + const at::Tensor& input, + const at::Tensor& offset, + const at::Tensor& weight, + std::pair stride, + std::pair pad, + std::pair dilation, + int groups, + int deformable_groups, + int n_parallel_imgs) { + + auto grad_input_and_offset = deform_conv_backward_input_cpu( + input, offset, weight, grad_out, + stride, pad, dilation, + groups, deformable_groups, n_parallel_imgs); + + auto grad_input = std::get<0>(grad_input_and_offset); + auto grad_offset = std::get<1>(grad_input_and_offset); + + auto grad_weight = deform_conv_backward_parameters_cpu( + input, offset, weight, grad_out, + stride, pad, dilation, + groups, deformable_groups, n_parallel_imgs); + + return {grad_input, grad_offset, grad_weight}; +} diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index d84b172ba49..9a151488f66 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -84,3 +84,26 @@ at::Tensor nms_cpu( const at::Tensor& dets, const at::Tensor& scores, const float iou_threshold); + +at::Tensor DCN_forward_cpu( + const at::Tensor& input, + const at::Tensor& offset, + const at::Tensor& weights, + std::pair stride, + std::pair pad, + std::pair dilation, + int groups, + int deformable_groups, + int n_parallel_imgs); + +std::tuple DCN_backward_cpu( + const at::Tensor& grad_out, + const at::Tensor& input, + const at::Tensor& offset, + const at::Tensor& weights, + std::pair stride, + std::pair pad, + std::pair dilation, + int groups, + int deformable_groups, + int n_parallel_imgs); diff --git a/torchvision/csrc/cuda/DeformConv_cuda.cu b/torchvision/csrc/cuda/DeformConv_cuda.cu new file mode 100644 index 00000000000..13a4ec21f32 --- /dev/null +++ b/torchvision/csrc/cuda/DeformConv_cuda.cu @@ -0,0 +1,728 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +// modified from https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp + + +#include +#include +#include +#include +#include + +#include "cuda_helpers.h" + +#include + + +using namespace at; + +const int CUDA_NUM_THREADS = 1024; +const int kMaxGridNum = 65535; + +inline int GET_BLOCKS(const int N) +{ + return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); +} + +template +__device__ scalar_t bilinear_interpolate(const scalar_t *in, const int height, const int width, scalar_t h, scalar_t w) +{ + if (h <= -1 || height <= h || w <= -1 || width <= w) { + return 0; + } + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = in[h_low * width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = in[h_low * width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = in[h_high * width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = in[h_high * width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t* input_ptr, const scalar_t* offset_ptr, + const int height, const int width, const int weight_h, const int weight_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dil_h, const int dil_w, + const int batch_sz, const int n_in_channels, const int n_offset_grps, + const int out_h, const int out_w, + scalar_t* columns_ptr) +{ + CUDA_1D_KERNEL_LOOP(index, n) + { + const int out_x = index % out_w; + const int out_y = (index / out_w) % out_h; + const int out_b = (index / (out_w * out_h)) % batch_sz; + const int in_c = index / (out_w * out_h * batch_sz); + const int out_c = in_c * weight_h * weight_w; + + int c_per_offset_grp = n_in_channels / n_offset_grps; + const int grp_idx = in_c / c_per_offset_grp; + + columns_ptr += (out_c * (batch_sz * out_h * out_w) + + out_b * (out_h * out_w) + + out_y * out_w + + out_x); + + input_ptr += (out_b * (n_in_channels * height * width) + + in_c * (height * width)); + + offset_ptr += (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w; + + for (int i = 0; i < weight_h; ++i) { + for (int j = 0; j < weight_w; ++j) { + const int offset_idx = 2 * (i * weight_w + j); + const scalar_t offset_h = offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t offset_w = offset_ptr[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t y = (out_y * stride_h - pad_h) + i * dil_h + offset_h; + const scalar_t x = (out_x * stride_w - pad_w) + j * dil_w + offset_w; + *columns_ptr = bilinear_interpolate(input_ptr, height, width, y, x); + columns_ptr += batch_sz * out_h * out_w; + } + } + } +} + +void deformable_im2col( + const at::Tensor input, const at::Tensor data_offset, int n_in_channels, + int height, int width, + int weight_h, int weight_w, + int pad_h, int pad_w, + int stride_h, int stride_w, + int dil_h, int dil_w, + int out_h, int out_w, + int parallel_imgs, int deformable_group, at::Tensor data_col) { + int num_kernels = n_in_channels * out_h * out_w * parallel_imgs; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "deformable_im2col_gpu", ([&] { + deformable_im2col_gpu_kernel<<>>( + num_kernels, + input.data_ptr(), + data_offset.data_ptr(), + height, width, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dil_h, dil_w, + parallel_imgs, n_in_channels, deformable_group, + out_h, out_w, + data_col.data_ptr()); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); + } +} + +void shape_check(at::Tensor input, at::Tensor offset, + at::Tensor weight, std::pair stride, std::pair pad, + std::pair dilation, int n_weight_grps, int n_offset_grps) { + + int in_h = input.size(2); + int in_w = input.size(3); + + int weight_h = weight.size(2); + int weight_w = weight.size(3); + + int stride_h = stride.first; + int stride_w = stride.second; + + int pad_h = pad.first; + int pad_w = pad.second; + + int dil_h = dilation.first; + int dil_w = dilation.second; + + int ker_h = dil_h * (weight_h - 1) + 1; + int ker_w = dil_w * (weight_w - 1) + 1; + int out_h = ((in_h + 2*pad_h - ker_h) / stride_h) + 1; + int out_w = ((in_w + 2*pad_w - ker_w) / stride_w) + 1; + + TORCH_CHECK(weight_h > 0 && weight_w > 0); + TORCH_CHECK(stride_h > 0 && stride_w > 0); + TORCH_CHECK(dil_h > 0 && dil_w > 0, "dil_h: ", dil_w, " dil_w: ", dil_h); + TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_w, " pad_w: ", pad_h); + + TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1)); + TORCH_CHECK(weight.size(0) % n_weight_grps == 0); + TORCH_CHECK(input.size(1) % n_offset_grps == 0); + + TORCH_CHECK((offset.size(0) == input.size(0)), "invalid batch size of offset"); + TORCH_CHECK((offset.size(1) == n_offset_grps * 2 * weight_h * weight_w), + "invalid number of channels of offset"); + TORCH_CHECK((offset.size(2) == out_h && offset.size(3) == out_w), + "offset output dims: (", offset.size(2), ", ", offset.size(3), + ") - output dims: (", out_h, ", ", out_w, ")"); + + TORCH_CHECK(out_h > 0 && out_w > 0, + "Calculated output size too small - out_h: ", out_h, " out_w: ", out_w); +} + + +at::Tensor DCN_forward_cuda( + const at::Tensor& input_param, + const at::Tensor& offset_param, + const at::Tensor& weight_param, + std::pair stride, + std::pair pad, + std::pair dilation, + int n_weight_grps, int n_offset_grps, int n_parallel_imgs) { + at::Tensor input = input_param; + at::Tensor offset = offset_param; + at::Tensor weight = weight_param; + + TORCH_CHECK(input.ndimension() == 4); + TORCH_CHECK(offset.ndimension() == 4); + TORCH_CHECK(weight.ndimension() == 4); + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(offset.is_contiguous()); + TORCH_CHECK(weight.is_contiguous()); + TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor"); + + at::DeviceGuard guard(input.device()); + + int batch_sz = input.size(0); + int in_channels = input.size(1); + int in_h = input.size(2); + int in_w = input.size(3); + + n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + + int out_channels = weight.size(0); + int weight_h = weight.size(2); + int weight_w = weight.size(3); + + int stride_h = stride.first; + int stride_w = stride.second; + + int pad_h = pad.first; + int pad_w = pad.second; + + int dil_h = dilation.first; + int dil_w = dilation.second; + + int ker_h = dil_h * (weight_h - 1) + 1; + int ker_w = dil_w * (weight_w - 1) + 1; + int out_h = ((in_h + 2*pad_h - ker_h) / stride_h) + 1; + int out_w = ((in_w + 2*pad_w - ker_w) / stride_w) + 1; + + + TORCH_CHECK(batch_sz % n_parallel_imgs == 0); + + TORCH_CHECK(weight_h > 0 && weight_w > 0, "weight_h: ", weight_w, " weight_w: ", weight_h); + TORCH_CHECK(stride_h > 0 && stride_w > 0, "stride_h: ", stride_w, " stride_w: ", stride_h); + TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_w, " pad_w: ", pad_h); + TORCH_CHECK(dil_h > 0 && dil_w > 0, "dil_h: ", dil_w, " dil_w: ", dil_h); + + TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1)); + TORCH_CHECK(weight.size(0) % n_weight_grps == 0); + TORCH_CHECK(input.size(1) % n_offset_grps == 0); + + TORCH_CHECK((offset.size(0) == input.size(0)), "invalid batch size of offset"); + TORCH_CHECK((offset.size(1) == n_offset_grps * 2 * weight_h * weight_w), + "got: ", offset.size(1), " expected: ", n_offset_grps * 2 * weight_h * weight_w); + TORCH_CHECK((offset.size(2) == out_h && offset.size(3) == out_w), + "offset output dims: (", offset.size(2), ", ", offset.size(3), ") - ", + "computed output dims: (", out_h, ", ", out_w, ")"); + TORCH_CHECK(out_h > 0 && out_w > 0, "Calculated output size too small - out_h: ", out_h, " out_w: ", out_w); + + + auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options()); + // Separate batches into blocks + out = out.view({batch_sz / n_parallel_imgs, n_parallel_imgs, out_channels, out_h, out_w}); + input = input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w}); + offset = offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + at::Tensor out_buf = at::zeros({batch_sz / n_parallel_imgs, out_channels, n_parallel_imgs * out_h, out_w}, out.options()); + + // Separate channels into convolution groups + out_buf = out_buf.view({out_buf.size(0), n_weight_grps, out_buf.size(1) / n_weight_grps, out_buf.size(2), out_buf.size(3)}); + weight = weight.view({n_weight_grps, weight.size(0) / n_weight_grps, weight.size(1), weight.size(2), weight.size(3)}); + + // Sample points and perform convolution + auto columns = at::zeros({in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, input.options()); + for (int b = 0; b < batch_sz / n_parallel_imgs; b++) { + deformable_im2col(input[b], offset[b], in_channels, in_h, + in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dil_h, + dil_w, out_h, out_w, n_parallel_imgs, n_offset_grps, columns); + + columns = columns.view({n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + for (int g = 0; g < n_weight_grps; g++) { + out_buf[b][g] = out_buf[b][g].flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(out_buf[b][g]); + } + } + + out_buf = out_buf.view({batch_sz / n_parallel_imgs, out_channels, n_parallel_imgs, out_h, out_w}); + out_buf.transpose_(1, 2); + out.copy_(out_buf); + out = out.view({batch_sz, out_channels, out_h, out_w}); + + return out; +} + + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, const scalar_t *col, const scalar_t *offset_ptr, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int batch_sz, const int n_offset_grps, + const int out_h, const int out_w, + scalar_t *grad_im) +{ + CUDA_1D_KERNEL_LOOP(index, n) + { + const int out_x = index % out_w; + const int out_y = (index / out_w) % out_h; + const int b = (index / (out_w * out_h)) % batch_sz; + const int j = (index / (out_w * out_h * batch_sz)) % kernel_w; + const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h; + const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h); + + int c_per_offset_grp = channels / n_offset_grps; + const int offset_grp = c / c_per_offset_grp; + + offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h * out_w; + const int offset_h_ptr = ((2 * (i * kernel_w + j)) * out_h + out_y) * out_w + out_x; + const int offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * out_h + out_y) * out_w + out_x; + const scalar_t offset_h = offset_ptr[offset_h_ptr]; + const scalar_t offset_w = offset_ptr[offset_w_ptr]; + const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + + for (int dy = -1; dy <= 1; dy++) { + for (int dx = -1; dx <= 1; dx++) { + int yp = int(y) + dy; + int xp = int(x) + dx; + if (0 <= yp && yp < height && + 0 <= xp && xp < width && + abs(y - yp) < 1 && + abs(x - xp) < 1) { + int grad_pos = ((b * channels + c) * height + yp) * width + xp; + scalar_t weight = (1 - abs(y - yp)) * (1 - abs(x - xp)); + atomicAdd(grad_im + grad_pos, weight * col[index]); + } + } + } + } +} + +void compute_grad_input( + const at::Tensor columns, const at::Tensor offset, const int channels, + const int height, const int width, const int weight_h, + const int weight_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int n_offset_grps, + at::Tensor grad_im) { + int out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + int out_w = (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * weight_h * weight_w * out_h * out_w * parallel_imgs; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + columns.scalar_type(), "deformable_col2im_gpu", ([&] { + deformable_col2im_gpu_kernel<<>>( + num_kernels, + columns.data_ptr(), + offset.data_ptr(), + channels, height, width, + weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + parallel_imgs, n_offset_grps, out_h, out_w, + grad_im.data_ptr()); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in compute_grad_input: %s\n", cudaGetErrorString(err)); + } +} + + +template +__device__ scalar_t get_coordinate_weight(const scalar_t *im_data, const int height, const int width, scalar_t y, scalar_t x, bool is_y_direction) { + int y_l = floor(y); + int x_l = floor(x); + int y_h = y_l + 1; + int x_h = x_l + 1; + + bool valid_y_l = 0 <= y_l && y_l < height; + bool valid_y_h = 0 <= y_h && y_h < height; + bool valid_x_l = 0 <= x_l && x_l < width; + bool valid_x_h = 0 <= x_h && x_h < width; + + scalar_t zero = 0; + scalar_t v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero; + scalar_t v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero; + scalar_t v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero; + scalar_t v_YX = (valid_y_h && valid_x_h) ? im_data[y_h * width + x_h] : zero; + + if (is_y_direction) { + scalar_t dx = x - x_l; + return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx); + } else { + scalar_t dy = y - y_l; + return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx); + } +} + + +template +__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *col_ptr, + const scalar_t *im_ptr, const scalar_t *offset_ptr, + const int channels, const int height, const int width, + const int weight_h, const int weight_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int batch_sz, const int offset_channels, const int n_offset_grps, + const int out_h, const int out_w, scalar_t *grad_offset) +{ + CUDA_1D_KERNEL_LOOP(index, n) + { + scalar_t val = 0; + int w = index % out_w; + int h = (index / out_w) % out_h; + int c = (index / (out_w * out_h)) % offset_channels; + int b = index / (out_w * out_h * offset_channels); + + const int offset_grp = c / (2 * weight_h * weight_w); + const int col_step = weight_h * weight_w; + + int c_per_offset_grp = channels / n_offset_grps; + + col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * out_w * out_h; + im_ptr += (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; + offset_ptr += (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h * out_w; + + const int offset_c = c - offset_grp * 2 * weight_h * weight_w; + const int is_y_direction = offset_c % 2 == 0; + + const int c_bound = c_per_offset_grp * weight_h * weight_w; + for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) { + const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w; + + int out_x = col_pos % out_w; + int out_y = (col_pos / out_w) % out_h; + int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w; + int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; + + const int offset_h_ptr = (((2 * (i * weight_w + j)) * out_h + out_y) * out_w + out_x); + const int offset_w_ptr = (((2 * (i * weight_w + j) + 1) * out_h + out_y) * out_w + out_x); + const scalar_t offset_h = offset_ptr[offset_h_ptr]; + const scalar_t offset_w = offset_ptr[offset_w_ptr]; + + scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + + const scalar_t weight = get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); + val += weight * col_ptr[col_pos]; + im_ptr += height * width; + } + + grad_offset[index] = val; + } +} + + +void compute_grad_offset( + const at::Tensor columns, const at::Tensor input, const at::Tensor offset, + const int channels, const int height, const int width, const int weight_h, + const int weight_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int n_offset_grps, at::Tensor grad_offset) { + int out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + int out_w = (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + int num_kernels = out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + columns.scalar_type(), "deformable_col2im_coord_gpu", ([&] { + deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, + columns.data_ptr(), + input.data_ptr(), + offset.data_ptr(), + channels, height, width, weight_h, + weight_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, + parallel_imgs, 2 * weight_h * weight_w * n_offset_grps, n_offset_grps, + out_h, out_w, + grad_offset.data_ptr()); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in compute_grad_offset: %s\n", cudaGetErrorString(err)); + } +} + + +std::tuple deform_conv_backward_input_cuda( + at::Tensor input, at::Tensor offset, at::Tensor weight, + at::Tensor grad_out, + std::pair stride, + std::pair pad, + std::pair dilation, + int n_weight_grps, int n_offset_grps, int n_parallel_imgs) { + at::DeviceGuard guard(input.device()); + + int batch_sz = input.size(0); + long n_in_channels = input.size(1); + long in_h = input.size(2); + long in_w = input.size(3); + + n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + + long n_out_channels = weight.size(0); + int weight_h = weight.size(2); + int weight_w = weight.size(3); + + int stride_h = stride.first; + int stride_w = stride.second; + + int pad_h = pad.first; + int pad_w = pad.second; + + int dil_h = dilation.first; + int dil_w = dilation.second; + + long out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) / stride_w + 1; + long out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) / stride_h + 1; + + auto grad_input = at::zeros_like(input); + auto grad_offset = at::zeros_like(offset); + auto columns = at::zeros({n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, input.options()); + + // Separate into blocks + grad_input = grad_input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + input = input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + grad_offset = grad_offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + offset = offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + + grad_out = grad_out.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_out_channels, out_h, out_w}); + grad_out.transpose_(1, 2); + grad_out = grad_out.view( + {grad_out.size(0), n_weight_grps, grad_out.size(1) / n_weight_grps, + grad_out.size(2), grad_out.size(3), grad_out.size(4)}); + + for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { + // Separate into weight groups + columns = columns.view({n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + weight = weight.view({n_weight_grps, weight.size(0) / n_weight_grps, weight.size(1), weight.size(2), weight.size(3)}); + for (int g = 0; g < n_weight_grps; g++) { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); + } + columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + + compute_grad_offset(columns, input[elt], offset[elt], n_in_channels, + in_h, in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, + dil_h, dil_w, n_parallel_imgs, n_offset_grps, + grad_offset[elt]); + + compute_grad_input(columns, offset[elt], n_in_channels, in_h, + in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dil_h, + dil_w, n_parallel_imgs, n_offset_grps, grad_input[elt]); + } + + grad_out = grad_out.view( + {grad_out.size(0), grad_out.size(1) * grad_out.size(2), + grad_out.size(3), grad_out.size(4), grad_out.size(5)}); + grad_out.transpose_(1, 2); + grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w}); + + grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); + input = input.view({batch_sz, n_in_channels, in_h, in_w}); + grad_offset = grad_offset.view({batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + offset = offset.view({batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + + return {grad_input, grad_offset}; +} + + + +at::Tensor deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor weight, + at::Tensor grad_out, + std::pair stride, + std::pair pad, + std::pair dilation, + int n_weight_grps, int n_offset_grps, int n_parallel_imgs) { + at::DeviceGuard guard(input.device()); + + int batch_sz = input.size(0); + long n_in_channels = input.size(1); + long in_h = input.size(2); + long in_w = input.size(3); + + n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + + long n_out_channels = weight.size(0); + int weight_h = weight.size(2); + int weight_w = weight.size(3); + + int stride_h = stride.first; + int stride_w = stride.second; + + int pad_h = pad.first; + int pad_w = pad.second; + + int dil_h = dilation.first; + int dil_w = dilation.second; + + long out_h = grad_out.size(2); + long out_w = grad_out.size(3); + + auto grad_weight = at::zeros_like(weight);; + auto columns = at::zeros({n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, input.options()); + + grad_out = grad_out.view({batch_sz / n_parallel_imgs, n_parallel_imgs, + n_out_channels, out_h, out_w}); + grad_out.transpose_(1, 2); + + at::Tensor grad_out_buf = at::zeros_like(grad_out); + grad_out_buf.copy_(grad_out); + grad_out_buf = grad_out_buf.view({batch_sz / n_parallel_imgs, n_out_channels, n_parallel_imgs * out_h, out_w}); + grad_out_buf = grad_out_buf.view({grad_out_buf.size(0), n_weight_grps, grad_out_buf.size(1) / n_weight_grps, grad_out_buf.size(2), grad_out_buf.size(3)}); + + grad_out.transpose_(1, 2); + grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w}); + + input = input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + offset = offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + + grad_weight = grad_weight.view({n_weight_grps, grad_weight.size(0) / n_weight_grps, grad_weight.size(1), grad_weight.size(2), grad_weight.size(3)}); + for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { + deformable_im2col(input[elt], offset[elt], n_in_channels, in_h, + in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dil_h, + dil_w, out_h, out_w, n_parallel_imgs, n_offset_grps, columns); + + columns = columns.view({n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + for (int g = 0; g < n_weight_grps; g++) { + grad_weight[g] = grad_weight[g] + .flatten(1) + .addmm_(grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) + .view_as(grad_weight[g]); + } + columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + input = input.view({batch_sz, n_in_channels, in_h, in_w}); + offset = offset.view({batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), grad_weight.size(4)}); + return grad_weight; +} + + +std::tuple DCN_backward_cuda( + const at::Tensor& grad_out, + const at::Tensor& input, + const at::Tensor& offset, + const at::Tensor& weight, + std::pair stride, + std::pair pad, + std::pair dilation, + int groups, + int deformable_groups, + int n_parallel_imgs) { + + auto grad_input_and_offset = deform_conv_backward_input_cuda( + input, offset, weight, grad_out, + stride, pad, dilation, + groups, deformable_groups, n_parallel_imgs); + + auto grad_input = std::get<0>(grad_input_and_offset); + auto grad_offset = std::get<1>(grad_input_and_offset); + + auto grad_weight = deform_conv_backward_parameters_cuda( + input, offset, weight, grad_out, + stride, pad, dilation, + groups, deformable_groups, n_parallel_imgs); + + return {grad_input, grad_offset, grad_weight}; +} diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index b35c4c909c1..320a5f15702 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -85,3 +85,26 @@ at::Tensor nms_cuda( const at::Tensor& dets, const at::Tensor& scores, const float iou_threshold); + +at::Tensor DCN_forward_cuda( + const at::Tensor& input, + const at::Tensor& offset, + const at::Tensor& weights, + std::pair stride, + std::pair pad, + std::pair dilation, + int groups, + int deformable_groups, + int n_parallel_imgs); + +std::tuple DCN_backward_cuda( + const at::Tensor& grad_out, + const at::Tensor& input, + const at::Tensor& offset, + const at::Tensor& weights, + std::pair stride, + std::pair pad, + std::pair dilation, + int groups, + int deformable_groups, + int n_parallel_imgs); diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index 3dc94bf9c78..b97d4c16f8f 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -11,6 +11,7 @@ #include "ROIPool.h" #include "empty_tensor_op.h" #include "nms.h" +#include "DeformConv.h" // If we are in a Windows environment, we need to define // initialization functions for the _custom_ops extension @@ -47,4 +48,5 @@ static auto registry = .op("torchvision::_new_empty_tensor_op", &new_empty_tensor) .op("torchvision::ps_roi_align", &ps_roi_align) .op("torchvision::ps_roi_pool", &ps_roi_pool) + .op("torchvision::deform_conv", &deform_conv) .op("torchvision::_cuda_version", &_cuda_version); diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 4921d2d0335..61b6e45c218 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -1,5 +1,9 @@ from .boxes import nms, box_iou +<<<<<<< HEAD from .new_empty_tensor import _new_empty_tensor +======= +from .deform_conv import deform_conv, DeformConv +>>>>>>> Add Deformable Convolution operation. from .roi_align import roi_align, RoIAlign from .roi_pool import roi_pool, RoIPool from .ps_roi_align import ps_roi_align, PSRoIAlign @@ -13,7 +17,7 @@ __all__ = [ - 'nms', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool', '_new_empty_tensor', - 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool', 'PSRoIPool', - 'MultiScaleRoIAlign', 'FeaturePyramidNetwork' + 'deform_conv', 'DeformConv', 'nms', 'roi_align', 'RoIAlign', 'roi_pool', + 'RoIPool', '_new_empty_tensor', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool', + 'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork' ] diff --git a/torchvision/ops/deform_conv.py b/torchvision/ops/deform_conv.py new file mode 100644 index 00000000000..87e9a87a89a --- /dev/null +++ b/torchvision/ops/deform_conv.py @@ -0,0 +1,70 @@ +from typing import Tuple + +import torch +from torch import nn, Tensor +from torch.nn.modules.utils import _pair +from torch.jit.annotations import List + + +def deform_conv(input, offset, weight, stride=(1, 1), pad=(0, 0), dilation=(1, 1), n_parallel_imgs=64): + # type: (Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int], int) -> Tensor + """ + Performs Deformable Convolution, described in Deformable Convolutional Networks + + Arguments: + input (Tensor[batch_sz, in_channels, in_h, in_w]): input tensor + offset (Tensor[batch_sz, 2 * n_offset_grps * weight_h * weight_w, out_h, out_w]) + weight (Tensor[out_channels, in_channels // n_weight_grps, weight_h, weight_w]): + convolution weights, with n_weight_grps different connection groups + stride (int or Tuple[int, int]): distance between convolution centers + pad (int or Tuple[int, int]): height/width of padding of zeroes around each image + dilation (int or Tuple[int, int]): point distance in convolution grid + n_parallel_imgs (int): Number of images to be processed at once; does not change + behavior, only used for performance purposes + + Returns: + output (Tensor[batch_sz, out_channels, out_h, out_w]): result of convolution + """ + + stride_h, stride_w = stride + pad_h, pad_w = pad + dil_h, dil_w = dilation + weights_h, weights_w = weight.shape[-2:] + _, n_in_channels, in_h, in_w = input.shape + + n_offset_grps = offset.shape[1] // (2 * weights_h * weights_w) + n_weight_grps = n_in_channels // weight.shape[1] + + return torch.ops.torchvision.deform_conv( + input, + offset, + weight, + *stride, + *pad, + *dilation, + n_weight_grps, + n_offset_grps, + n_parallel_imgs) + + +class DeformConv(nn.Module): + """ + See deform_conv + """ + def __init__(self, stride=1, pad=0, dilation=1, n_parallel_imgs=64): + super(DeformConv, self).__init__() + self.stride = _pair(stride) + self.pad = _pair(pad) + self.dilation = _pair(dilation) + self.n_parallel_imgs = n_parallel_imgs + + def forward(self, input, offset, weight): + return deform_conv(input, offset, weight, stride=self.stride, pad=self.pad, + dilation=self.dilation, n_parallel_imgs=self.n_parallel_imgs) + + def __repr__(self): + tmpstr = self.__class__.__name__ + '(' + tmpstr += 'output_size=' + str(self.output_size) + tmpstr += ', spatial_scale=' + str(self.spatial_scale) + tmpstr += ')' + return tmpstr From f7c49841f010382311836fa90724259977c59aad Mon Sep 17 00:00:00 2001 From: Pedro Freire Date: Wed, 20 Nov 2019 10:05:45 +0000 Subject: [PATCH 2/9] Update DeformConv to be more consistent w/ Conv2d * rename some variables and arguments to match Conv2d; * add optional bias; * add weight, offset and bias as module parameters; * remove the n_parallel_imgs parameter; * Fix __repr__; * etc.. Initialization of weight and bias is the same as in Conv2d, and initialization of offsets to zero is the same as in the paper. This also includes some other small unrelated fixes/improvements. --- test/test_ops.py | 73 ++++++++----- torchvision/csrc/DeformConv.h | 91 ++++++++-------- torchvision/csrc/cpu/DeformConv_cpu.cpp | 70 +++++++----- torchvision/csrc/cpu/vision_cpu.h | 16 +-- torchvision/csrc/cuda/DeformConv_cuda.cu | 104 +++++++----------- torchvision/csrc/cuda/vision_cuda.h | 16 +-- torchvision/csrc/vision.cpp | 2 +- torchvision/ops/__init__.py | 7 +- torchvision/ops/deform_conv.py | 132 ++++++++++++++++------- 9 files changed, 282 insertions(+), 229 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 352fbac2260..2deef26b861 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -383,23 +383,23 @@ def test_new_empty_tensor(self): class DeformConvTester(OpTester, unittest.TestCase): - def expected_fn(self, x, offsets, weights, *args, stride=1, pad=0, dilation=1): + def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1): stride_h, stride_w = _pair(stride) - pad_h, pad_w = _pair(pad) + pad_h, pad_w = _pair(padding) dil_h, dil_w = _pair(dilation) - weights_h, weights_w = weights.shape[-2:] + weight_h, weight_w = weight.shape[-2:] n_batches, n_in_channels, in_h, in_w = x.shape - n_out_channels = weights.shape[0] + n_out_channels = weight.shape[0] - out_h = (in_h + 2 * pad_h - (dil_h * (weights_h - 1) + 1)) // stride_h + 1 - out_w = (in_w + 2 * pad_w - (dil_w * (weights_w - 1) + 1)) // stride_w + 1 + out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1 + out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1 - n_offset_grps = offsets.shape[1] // (2 * weights_h * weights_w) + n_offset_grps = offset.shape[1] // (2 * weight_h * weight_w) in_c_per_offset_grp = n_in_channels // n_offset_grps - n_weight_grps = n_in_channels // weights.shape[1] - in_c_per_weight_grp = weights.shape[1] + n_weight_grps = n_in_channels // weight.shape[1] + in_c_per_weight_grp = weight.shape[1] out_c_per_weight_grp = n_out_channels // n_weight_grps out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype) @@ -407,20 +407,21 @@ def expected_fn(self, x, offsets, weights, *args, stride=1, pad=0, dilation=1): for c_out in range(n_out_channels): for i in range(out_h): for j in range(out_w): - for di in range(weights_h): - for dj in range(weights_w): + for di in range(weight_h): + for dj in range(weight_w): for c in range(in_c_per_weight_grp): weight_grp = c_out // out_c_per_weight_grp c_in = weight_grp * in_c_per_weight_grp + c offset_grp = c_in // in_c_per_offset_grp - offset_idx = 2 * (offset_grp * (weights_h * weights_w) + di * weights_w + dj) + offset_idx = 2 * (offset_grp * (weight_h * weight_w) + di * weight_w + dj) - pi = stride_h * i - pad_h + dil_h * di + offsets[b, offset_idx, i, j] - pj = stride_w * j - pad_w + dil_w * dj + offsets[b, offset_idx + 1, i, j] + pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j] + pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j] - out[b, c_out, i, j] += (weights[c_out, c, di, dj] * + out[b, c_out, i, j] += (weight[c_out, c, di, dj] * bilinear_interpolate(x[b, c_in, :, :], pi, pj)) + out += bias.view(1, n_out_channels, 1, 1) return out def get_fn_args(self, device, contiguous): @@ -451,36 +452,50 @@ def get_fn_args(self, device, contiguous): weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w, device=device, dtype=self.dtype, requires_grad=True) + bias = torch.randn(n_out_channels, device=device, dtype=self.dtype, requires_grad=True) + if not contiguous: x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2) offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1) weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0) - return x, offset, weight, stride, pad, dilation + return x, weight, offset, bias, stride, pad, dilation def _test_forward(self, device, contiguous): - x, offset, weight, stride, pad, dilation = self.get_fn_args(device, contiguous) + x, _, _, _, stride, padding, dilation = self.get_fn_args(device, contiguous) + in_channels = 6 + out_channels = 2 + kernel_size = (3, 2) + groups = 2 + offset_groups = 3 + + layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, offset_groups=offset_groups) + layer.offset_conv.weight.data = torch.randn_like(layer.offset_conv.weight.data) + res = layer(x) - res = ops.DeformConv(stride=stride, pad=pad, dilation=dilation)(x, offset, weight) - expected = self.expected_fn(x, offset, weight, stride=stride, pad=pad, dilation=dilation) + weight = layer.weight.data.to(device=x.device, dtype=x.dtype) + offset = layer.offset_conv.to(device=x.device, dtype=x.dtype)(x) + bias = layer.bias.data.to(device=x.device, dtype=x.dtype) + expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation) - self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(x, res, expected)) + self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(res, expected)) def _test_backward(self, device, contiguous): - x, offset, weight, stride, pad, dilation = self.get_fn_args(device, contiguous) + x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous) - def func(x_, offset_, weight_): - return ops.deform_conv(x_, offset_, weight_, stride=stride, pad=pad, dilation=dilation) + def func(x_, weight_, offset_, bias_): + return ops.deform_conv2d(x_, weight_, offset_, bias_, stride=stride, padding=padding, dilation=dilation) - gradcheck(func, (x, offset, weight), nondet_tol=1e-5) + gradcheck(func, (x, weight, offset, bias), nondet_tol=1e-5) @torch.jit.script - def script_func(x_, offset_, weight_, stride_, pad_, dilation_): - # type: (Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor - return ops.deform_conv(x_, offset_, weight_, stride=stride_, pad=pad_, dilation=dilation_) + def script_func(x_, weight_, offset_, bias_, stride_, pad_, dilation_): + # type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor + return ops.deform_conv2d(x_, weight_, offset_, bias_, stride=stride_, padding=pad_, dilation=dilation_) - gradcheck(lambda z, off, wei: script_func(z, off, wei, stride, pad, dilation), - (x, offset, weight), nondet_tol=1e-5) + gradcheck(lambda z, wei, off, bi: script_func(z, wei, off, bi, stride, padding, dilation), + (x, weight, offset, bias), nondet_tol=1e-5) if __name__ == '__main__': diff --git a/torchvision/csrc/DeformConv.h b/torchvision/csrc/DeformConv.h index 1f04259e9c8..b057f611380 100644 --- a/torchvision/csrc/DeformConv.h +++ b/torchvision/csrc/DeformConv.h @@ -6,49 +6,48 @@ #include "cuda/vision_cuda.h" #endif -at::Tensor DCN_forward( +at::Tensor DeformConv2d_forward( const Tensor& input, + const Tensor& weight, const Tensor& offset, - const Tensor& weights, + const Tensor& bias, const std::pair& stride, - const std::pair& pad, + const std::pair& padding, const std::pair& dilation, - const int groups, - const int deformable_groups, - const int n_parallel_imgs) { + const int groups, const int offset_groups) { if (input.type().is_cuda()) { #ifdef WITH_CUDA - return DCN_forward_cuda(input.contiguous(), offset.contiguous(), weights.contiguous(), stride, pad, - dilation, groups, deformable_groups, n_parallel_imgs); + return DeformConv2d_forward_cuda(input.contiguous(), weight.contiguous(), offset.contiguous(), + bias.contiguous(), stride, padding, dilation, groups, offset_groups); #else AT_ERROR("Not compiled with GPU support"); #endif } - return DCN_forward_cpu(input.contiguous(), offset.contiguous(), weights.contiguous(), stride, pad, - dilation, groups, deformable_groups, n_parallel_imgs); + return DeformConv2d_forward_cpu(input.contiguous(), weight.contiguous(), offset.contiguous(), + bias.contiguous(), stride, padding, dilation, groups, offset_groups); } -std::tuple DCN_backward( +std::tuple DeformConv2d_backward( const at::Tensor& grad, const Tensor& input, + const Tensor& weight, const Tensor& offset, - const Tensor& weights, + const Tensor& bias, const std::pair& stride, - const std::pair& pad, + const std::pair& padding, const std::pair& dilation, const int groups, - const int deformable_groups, - const int n_parallel_imgs) { + const int offset_groups) { if (grad.type().is_cuda()) { #ifdef WITH_CUDA - return DCN_backward_cuda(grad.contiguous(), input.contiguous(), offset.contiguous(), weights.contiguous(), stride, pad, - dilation, groups, deformable_groups, n_parallel_imgs); + return DeformConv2d_backward_cuda(grad.contiguous(), input.contiguous(), weight.contiguous(), offset.contiguous(), + bias.contiguous(), stride, padding, dilation, groups, offset_groups); #else AT_ERROR("Not compiled with GPU support"); #endif } - return DCN_backward_cpu(grad.contiguous(), input.contiguous(), offset.contiguous(), weights.contiguous(), stride, pad, - dilation, groups, deformable_groups, n_parallel_imgs); + return DeformConv2d_backward_cpu(grad.contiguous(), input.contiguous(), weight.contiguous(), offset.contiguous(), + bias.contiguous(), stride, padding, dilation, groups, offset_groups); } using namespace at; @@ -57,26 +56,27 @@ using torch::autograd::AutogradContext; using torch::autograd::Variable; using torch::autograd::variable_list; -class DeformConvFunction : public torch::autograd::Function { +class DeformConv2dFunction : public torch::autograd::Function { public: static variable_list forward( AutogradContext* ctx, Variable input, + Variable weight, Variable offset, - Variable weights, + Variable bias, int64_t stride_h, int64_t stride_w, int64_t pad_h, int64_t pad_w, int64_t dilation_h, int64_t dilation_w, int64_t groups, - int64_t deformable_groups, - int64_t n_parallel_imgs) { - auto output = DCN_forward(input, offset, weights, + int64_t offset_groups) { + auto output = DeformConv2d_forward( + input, weight, offset, bias, {stride_h, stride_w}, {pad_h, pad_w}, {dilation_h, dilation_w}, - groups, deformable_groups, n_parallel_imgs); + groups, offset_groups); - ctx->save_for_backward({input, offset, weights}); + ctx->save_for_backward({input, weight, offset, bias}); ctx->saved_data["stride_h"] = stride_h; ctx->saved_data["stride_w"] = stride_w; ctx->saved_data["pad_h"] = pad_h; @@ -84,8 +84,7 @@ class DeformConvFunction : public torch::autograd::Function ctx->saved_data["dilation_h"] = dilation_h; ctx->saved_data["dilation_w"] = dilation_w; ctx->saved_data["groups"] = groups; - ctx->saved_data["deformable_groups"] = deformable_groups; - ctx->saved_data["n_parallel_imgs"] = n_parallel_imgs; + ctx->saved_data["offset_groups"] = offset_groups; return {output,}; } @@ -95,8 +94,9 @@ class DeformConvFunction : public torch::autograd::Function variable_list grad_output) { auto saved = ctx->get_saved_variables(); auto input = saved[0]; - auto offset = saved[1]; - auto weight = saved[2]; + auto weight = saved[1]; + auto offset = saved[2]; + auto bias = saved[3]; auto stride_h = ctx->saved_data["stride_h"].toInt(); auto stride_w = ctx->saved_data["stride_w"].toInt(); @@ -105,37 +105,36 @@ class DeformConvFunction : public torch::autograd::Function auto dilation_h = ctx->saved_data["dilation_h"].toInt(); auto dilation_w = ctx->saved_data["dilation_w"].toInt(); auto groups = ctx->saved_data["groups"].toInt(); - auto deformable_groups = ctx->saved_data["deformable_groups"].toInt(); - auto n_parallel_imgs = ctx->saved_data["n_parallel_imgs"].toInt(); + auto offset_groups = ctx->saved_data["offset_groups"].toInt(); - auto grads = DCN_backward(grad_output[0], - input, offset, weight, + auto grads = DeformConv2d_backward(grad_output[0], + input, weight, offset, bias, {stride_h, stride_w}, {pad_h, pad_w}, {dilation_h, dilation_w}, - groups, deformable_groups, n_parallel_imgs); + groups, offset_groups); auto grad_input = std::get<0>(grads); - auto grad_offset = std::get<1>(grads); - auto grad_weight = std::get<2>(grads); + auto grad_weight = std::get<1>(grads); + auto grad_offset = std::get<2>(grads); + auto grad_bias = std::get<3>(grads); - return {grad_input, grad_offset, grad_weight, - Variable(), Variable(), Variable(), + return {grad_input, grad_weight, grad_offset, + grad_bias, Variable(), Variable(), Variable(), Variable(), Variable(), Variable(), Variable(), Variable(),}; } }; -Tensor deform_conv( +Tensor deform_conv2d( const Tensor& input, + const Tensor& weight, const Tensor& offset, - const Tensor& weights, + const Tensor& bias, int64_t stride_h, int64_t stride_w, int64_t pad_h, int64_t pad_w, int64_t dilation_h, int64_t dilation_w, - int64_t groups, - int64_t deformable_groups, - int64_t n_parallel_imgs) { - auto result = DeformConvFunction::apply(input, offset, weights, stride_h, stride_w, pad_h, pad_w, - dilation_h, dilation_w, groups, deformable_groups, n_parallel_imgs); + int64_t groups, int64_t offset_groups) { + auto result = DeformConv2dFunction::apply(input, weight, offset, bias, stride_h, stride_w, pad_h, pad_w, + dilation_h, dilation_w, groups, offset_groups); return result[0]; } diff --git a/torchvision/csrc/cpu/DeformConv_cpu.cpp b/torchvision/csrc/cpu/DeformConv_cpu.cpp index eb47a652c51..ecbe66f6e1d 100644 --- a/torchvision/csrc/cpu/DeformConv_cpu.cpp +++ b/torchvision/csrc/cpu/DeformConv_cpu.cpp @@ -69,9 +69,11 @@ #include - using namespace at; + +const int kMaxParallelImgs = 32; + template static scalar_t bilinear_interpolate(const scalar_t *in, const int height, const int width, scalar_t h, scalar_t w) { if (h <= -1 || height <= h || w <= -1 || width <= w) { @@ -172,14 +174,24 @@ static void deformable_im2col( })); } -at::Tensor DCN_forward_cpu( +static int get_greatest_divisor_below_bound(int n, int bound) { + for(int k = bound; k > 1; --k) { + if(n % k == 0) { + return k; + } + } + return 1; +} + +at::Tensor DeformConv2d_forward_cpu( const at::Tensor& input_param, - const at::Tensor& offset_param, const at::Tensor& weight_param, + const at::Tensor& offset_param, + const at::Tensor& bias, std::pair stride, std::pair pad, std::pair dilation, - int n_weight_grps, int n_offset_grps, int n_parallel_imgs) { + int n_weight_grps, int n_offset_grps) { at::Tensor input = input_param; at::Tensor offset = offset_param; at::Tensor weight = weight_param; @@ -197,7 +209,7 @@ at::Tensor DCN_forward_cpu( int in_h = input.size(2); int in_w = input.size(3); - n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + int n_parallel_imgs = get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); // Unpack shapes and args int out_channels = weight.size(0); @@ -219,12 +231,10 @@ at::Tensor DCN_forward_cpu( int out_w = ((in_w + 2*pad_w - ker_w) / stride_w) + 1; - TORCH_CHECK(batch_sz % n_parallel_imgs == 0); - - TORCH_CHECK(weight_h > 0 && weight_w > 0, "weight_h: ", weight_w, " weight_w: ", weight_h); - TORCH_CHECK(stride_h > 0 && stride_w > 0, "stride_h: ", stride_w, " stride_w: ", stride_h); - TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_w, " pad_w: ", pad_h); - TORCH_CHECK(dil_h > 0 && dil_w > 0, "dil_h: ", dil_w, " dil_w: ", dil_h); + TORCH_CHECK(weight_h > 0 && weight_w > 0, "weight_h: ", weight_h, " weight_w: ", weight_w); + TORCH_CHECK(stride_h > 0 && stride_w > 0, "stride_h: ", stride_h, " stride_w: ", stride_w); + TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); + TORCH_CHECK(dil_h > 0 && dil_w > 0, "dil_h: ", dil_h, " dil_w: ", dil_w); TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1)); TORCH_CHECK(weight.size(0) % n_weight_grps == 0); @@ -240,6 +250,7 @@ at::Tensor DCN_forward_cpu( auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options()); + // Separate batches into blocks out = out.view({batch_sz / n_parallel_imgs, n_parallel_imgs, out_channels, out_h, out_w}); input = input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); @@ -270,7 +281,7 @@ at::Tensor DCN_forward_cpu( out.copy_(out_buf); out = out.view({batch_sz, out_channels, out_h, out_w}); - return out; + return out + bias.view({1, out_channels, 1, 1}); } @@ -458,8 +469,8 @@ static void compute_grad_offset( } -static std::tuple deform_conv_backward_input_cpu( - at::Tensor input, at::Tensor offset, at::Tensor weight, +static std::tuple deform_conv2d_backward_input_cpu( + at::Tensor input, at::Tensor weight, at::Tensor offset, at::Tensor grad_out, std::pair stride, std::pair pad, @@ -540,8 +551,8 @@ static std::tuple deform_conv_backward_input_cpu( -static at::Tensor deform_conv_backward_parameters_cpu( - at::Tensor input, at::Tensor offset, at::Tensor weight, +static at::Tensor deform_conv2d_backward_parameters_cpu( + at::Tensor input, at::Tensor weight, at::Tensor offset, at::Tensor grad_out, std::pair stride, std::pair pad, @@ -614,30 +625,33 @@ static at::Tensor deform_conv_backward_parameters_cpu( } -std::tuple DCN_backward_cpu( +std::tuple DeformConv2d_backward_cpu( const at::Tensor& grad_out, const at::Tensor& input, - const at::Tensor& offset, const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& bias, std::pair stride, std::pair pad, std::pair dilation, - int groups, - int deformable_groups, - int n_parallel_imgs) { + int n_weight_grps, int n_offset_grps) { + const int batch_sz = input.size(0); + const int n_parallel_imgs = get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); - auto grad_input_and_offset = deform_conv_backward_input_cpu( - input, offset, weight, grad_out, + auto grad_input_and_offset = deform_conv2d_backward_input_cpu( + input, weight, offset, grad_out, stride, pad, dilation, - groups, deformable_groups, n_parallel_imgs); + n_weight_grps, n_offset_grps, n_parallel_imgs); auto grad_input = std::get<0>(grad_input_and_offset); auto grad_offset = std::get<1>(grad_input_and_offset); - auto grad_weight = deform_conv_backward_parameters_cpu( - input, offset, weight, grad_out, + auto grad_weight = deform_conv2d_backward_parameters_cpu( + input, weight, offset, grad_out, stride, pad, dilation, - groups, deformable_groups, n_parallel_imgs); + n_weight_grps, n_offset_grps, n_parallel_imgs); + + auto grad_bias = at::ones_like(bias) * grad_out.sum({0, 2, 3}); - return {grad_input, grad_offset, grad_weight}; + return {grad_input, grad_weight, grad_offset, grad_bias}; } diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index 9a151488f66..0cb03c7c782 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -85,25 +85,25 @@ at::Tensor nms_cpu( const at::Tensor& scores, const float iou_threshold); -at::Tensor DCN_forward_cpu( +at::Tensor DeformConv2d_forward_cpu( const at::Tensor& input, + const at::Tensor& weight, const at::Tensor& offset, - const at::Tensor& weights, + const at::Tensor& bias, std::pair stride, std::pair pad, std::pair dilation, int groups, - int deformable_groups, - int n_parallel_imgs); + int deformable_groups); -std::tuple DCN_backward_cpu( +std::tuple DeformConv2d_backward_cpu( const at::Tensor& grad_out, const at::Tensor& input, + const at::Tensor& weight, const at::Tensor& offset, - const at::Tensor& weights, + const at::Tensor& bias, std::pair stride, std::pair pad, std::pair dilation, int groups, - int deformable_groups, - int n_parallel_imgs); + int deformable_groups); diff --git a/torchvision/csrc/cuda/DeformConv_cuda.cu b/torchvision/csrc/cuda/DeformConv_cuda.cu index 13a4ec21f32..225daea8b8c 100644 --- a/torchvision/csrc/cuda/DeformConv_cuda.cu +++ b/torchvision/csrc/cuda/DeformConv_cuda.cu @@ -79,6 +79,8 @@ using namespace at; const int CUDA_NUM_THREADS = 1024; const int kMaxGridNum = 65535; +const int kMaxParallelImgs = 32; + inline int GET_BLOCKS(const int N) { return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); @@ -193,62 +195,27 @@ void deformable_im2col( } } -void shape_check(at::Tensor input, at::Tensor offset, - at::Tensor weight, std::pair stride, std::pair pad, - std::pair dilation, int n_weight_grps, int n_offset_grps) { - - int in_h = input.size(2); - int in_w = input.size(3); - - int weight_h = weight.size(2); - int weight_w = weight.size(3); - - int stride_h = stride.first; - int stride_w = stride.second; - - int pad_h = pad.first; - int pad_w = pad.second; - - int dil_h = dilation.first; - int dil_w = dilation.second; - - int ker_h = dil_h * (weight_h - 1) + 1; - int ker_w = dil_w * (weight_w - 1) + 1; - int out_h = ((in_h + 2*pad_h - ker_h) / stride_h) + 1; - int out_w = ((in_w + 2*pad_w - ker_w) / stride_w) + 1; - - TORCH_CHECK(weight_h > 0 && weight_w > 0); - TORCH_CHECK(stride_h > 0 && stride_w > 0); - TORCH_CHECK(dil_h > 0 && dil_w > 0, "dil_h: ", dil_w, " dil_w: ", dil_h); - TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_w, " pad_w: ", pad_h); - - TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1)); - TORCH_CHECK(weight.size(0) % n_weight_grps == 0); - TORCH_CHECK(input.size(1) % n_offset_grps == 0); - - TORCH_CHECK((offset.size(0) == input.size(0)), "invalid batch size of offset"); - TORCH_CHECK((offset.size(1) == n_offset_grps * 2 * weight_h * weight_w), - "invalid number of channels of offset"); - TORCH_CHECK((offset.size(2) == out_h && offset.size(3) == out_w), - "offset output dims: (", offset.size(2), ", ", offset.size(3), - ") - output dims: (", out_h, ", ", out_w, ")"); - - TORCH_CHECK(out_h > 0 && out_w > 0, - "Calculated output size too small - out_h: ", out_h, " out_w: ", out_w); +int get_greatest_divisor_below_bound(int n, int bound) { + for(int k = bound; k > 1; --k) { + if(n % k == 0) { + return k; + } + } + return 1; } - -at::Tensor DCN_forward_cuda( +at::Tensor DeformConv2d_forward_cuda( const at::Tensor& input_param, - const at::Tensor& offset_param, const at::Tensor& weight_param, + const at::Tensor& offset_param, + const at::Tensor& bias, std::pair stride, std::pair pad, std::pair dilation, - int n_weight_grps, int n_offset_grps, int n_parallel_imgs) { + int n_weight_grps, int n_offset_grps) { at::Tensor input = input_param; - at::Tensor offset = offset_param; at::Tensor weight = weight_param; + at::Tensor offset = offset_param; TORCH_CHECK(input.ndimension() == 4); TORCH_CHECK(offset.ndimension() == 4); @@ -265,7 +232,7 @@ at::Tensor DCN_forward_cuda( int in_h = input.size(2); int in_w = input.size(3); - n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + int n_parallel_imgs = get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); int out_channels = weight.size(0); int weight_h = weight.size(2); @@ -286,12 +253,10 @@ at::Tensor DCN_forward_cuda( int out_w = ((in_w + 2*pad_w - ker_w) / stride_w) + 1; - TORCH_CHECK(batch_sz % n_parallel_imgs == 0); - - TORCH_CHECK(weight_h > 0 && weight_w > 0, "weight_h: ", weight_w, " weight_w: ", weight_h); - TORCH_CHECK(stride_h > 0 && stride_w > 0, "stride_h: ", stride_w, " stride_w: ", stride_h); - TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_w, " pad_w: ", pad_h); - TORCH_CHECK(dil_h > 0 && dil_w > 0, "dil_h: ", dil_w, " dil_w: ", dil_h); + TORCH_CHECK(weight_h > 0 && weight_w > 0, "weight_h: ", weight_h, " weight_w: ", weight_w); + TORCH_CHECK(stride_h > 0 && stride_w > 0, "stride_h: ", stride_h, " stride_w: ", stride_w); + TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); + TORCH_CHECK(dil_h > 0 && dil_w > 0, "dil_h: ", dil_h, " dil_w: ", dil_w); TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1)); TORCH_CHECK(weight.size(0) % n_weight_grps == 0); @@ -307,6 +272,7 @@ at::Tensor DCN_forward_cuda( auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options()); + // Separate batches into blocks out = out.view({batch_sz / n_parallel_imgs, n_parallel_imgs, out_channels, out_h, out_w}); input = input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w}); @@ -337,7 +303,7 @@ at::Tensor DCN_forward_cuda( out.copy_(out_buf); out = out.view({batch_sz, out_channels, out_h, out_w}); - return out; + return out + bias.view({1, out_channels, 1, 1}); } @@ -542,7 +508,7 @@ void compute_grad_offset( std::tuple deform_conv_backward_input_cuda( - at::Tensor input, at::Tensor offset, at::Tensor weight, + at::Tensor input, at::Tensor weight, at::Tensor offset, at::Tensor grad_out, std::pair stride, std::pair pad, @@ -625,7 +591,7 @@ std::tuple deform_conv_backward_input_cuda( at::Tensor deform_conv_backward_parameters_cuda( - at::Tensor input, at::Tensor offset, at::Tensor weight, + at::Tensor input, at::Tensor weight, at::Tensor offset, at::Tensor grad_out, std::pair stride, std::pair pad, @@ -699,30 +665,34 @@ at::Tensor deform_conv_backward_parameters_cuda( } -std::tuple DCN_backward_cuda( +std::tuple DeformConv2d_backward_cuda( const at::Tensor& grad_out, const at::Tensor& input, - const at::Tensor& offset, const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& bias, std::pair stride, std::pair pad, std::pair dilation, - int groups, - int deformable_groups, - int n_parallel_imgs) { + int n_weight_grps, int n_offset_grps) { + const int batch_sz = input.size(0); + const int n_parallel_imgs = get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); auto grad_input_and_offset = deform_conv_backward_input_cuda( - input, offset, weight, grad_out, + input, weight, offset, grad_out, stride, pad, dilation, - groups, deformable_groups, n_parallel_imgs); + n_weight_grps, n_offset_grps, n_parallel_imgs); auto grad_input = std::get<0>(grad_input_and_offset); auto grad_offset = std::get<1>(grad_input_and_offset); auto grad_weight = deform_conv_backward_parameters_cuda( - input, offset, weight, grad_out, + input, weight, offset, grad_out, stride, pad, dilation, - groups, deformable_groups, n_parallel_imgs); + n_weight_grps, n_offset_grps, n_parallel_imgs); + + auto value = grad_out.sum({0, 2, 3}); + auto grad_bias = at::ones_like(bias) * value; - return {grad_input, grad_offset, grad_weight}; + return {grad_input, grad_weight, grad_offset, grad_bias}; } diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index 320a5f15702..fe7655a52aa 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -86,25 +86,25 @@ at::Tensor nms_cuda( const at::Tensor& scores, const float iou_threshold); -at::Tensor DCN_forward_cuda( +at::Tensor DeformConv2d_forward_cuda( const at::Tensor& input, + const at::Tensor& weight, const at::Tensor& offset, - const at::Tensor& weights, + const at::Tensor& bias, std::pair stride, std::pair pad, std::pair dilation, int groups, - int deformable_groups, - int n_parallel_imgs); + int deformable_groups); -std::tuple DCN_backward_cuda( +std::tuple DeformConv2d_backward_cuda( const at::Tensor& grad_out, const at::Tensor& input, + const at::Tensor& weight, const at::Tensor& offset, - const at::Tensor& weights, + const at::Tensor& bias, std::pair stride, std::pair pad, std::pair dilation, int groups, - int deformable_groups, - int n_parallel_imgs); + int deformable_groups); diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index b97d4c16f8f..12cc7c87675 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -48,5 +48,5 @@ static auto registry = .op("torchvision::_new_empty_tensor_op", &new_empty_tensor) .op("torchvision::ps_roi_align", &ps_roi_align) .op("torchvision::ps_roi_pool", &ps_roi_pool) - .op("torchvision::deform_conv", &deform_conv) + .op("torchvision::deform_conv2d", &deform_conv2d) .op("torchvision::_cuda_version", &_cuda_version); diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 61b6e45c218..0ff2b0be2ce 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -1,9 +1,6 @@ from .boxes import nms, box_iou -<<<<<<< HEAD from .new_empty_tensor import _new_empty_tensor -======= -from .deform_conv import deform_conv, DeformConv ->>>>>>> Add Deformable Convolution operation. +from .deform_conv import deform_conv2d, DeformConv2d from .roi_align import roi_align, RoIAlign from .roi_pool import roi_pool, RoIPool from .ps_roi_align import ps_roi_align, PSRoIAlign @@ -17,7 +14,7 @@ __all__ = [ - 'deform_conv', 'DeformConv', 'nms', 'roi_align', 'RoIAlign', 'roi_pool', + 'deform_conv2d', 'DeformConv2d', 'nms', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool', '_new_empty_tensor', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool', 'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork' ] diff --git a/torchvision/ops/deform_conv.py b/torchvision/ops/deform_conv.py index 87e9a87a89a..bb72a390f87 100644 --- a/torchvision/ops/deform_conv.py +++ b/torchvision/ops/deform_conv.py @@ -1,70 +1,128 @@ -from typing import Tuple +import math import torch from torch import nn, Tensor +from torch.nn import init +from torch.nn.parameter import Parameter from torch.nn.modules.utils import _pair -from torch.jit.annotations import List +from torch.jit.annotations import Tuple -def deform_conv(input, offset, weight, stride=(1, 1), pad=(0, 0), dilation=(1, 1), n_parallel_imgs=64): - # type: (Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int], int) -> Tensor +def deform_conv2d(input, weight, offset, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1)): + # type: (Tensor, Tensor, Tensor, Optional[Tensor], Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor """ Performs Deformable Convolution, described in Deformable Convolutional Networks Arguments: - input (Tensor[batch_sz, in_channels, in_h, in_w]): input tensor - offset (Tensor[batch_sz, 2 * n_offset_grps * weight_h * weight_w, out_h, out_w]) - weight (Tensor[out_channels, in_channels // n_weight_grps, weight_h, weight_w]): - convolution weights, with n_weight_grps different connection groups - stride (int or Tuple[int, int]): distance between convolution centers - pad (int or Tuple[int, int]): height/width of padding of zeroes around each image - dilation (int or Tuple[int, int]): point distance in convolution grid - n_parallel_imgs (int): Number of images to be processed at once; does not change - behavior, only used for performance purposes + input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor + weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]): + convolution weights, split into groups of size (in_channels // groups) + offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width, + out_height, out_width]): offsets to be applied for each position in the + convolution kernel. + bias (Tensor[out_channels]): optional bias of shape (out_channels,). Default: None + stride (int or Tuple[int, int]): distance between convolution centers. Default: 1 + padding (int or Tuple[int, int]): height/width of padding of zeroes around + each image. Default: 0 + dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1 Returns: output (Tensor[batch_sz, out_channels, out_h, out_w]): result of convolution """ - stride_h, stride_w = stride - pad_h, pad_w = pad - dil_h, dil_w = dilation + out_channels = weight.shape[0] + if bias is None: + bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype) + + stride_h, stride_w = _pair(stride) + pad_h, pad_w = _pair(padding) + dil_h, dil_w = _pair(dilation) weights_h, weights_w = weight.shape[-2:] _, n_in_channels, in_h, in_w = input.shape n_offset_grps = offset.shape[1] // (2 * weights_h * weights_w) n_weight_grps = n_in_channels // weight.shape[1] - return torch.ops.torchvision.deform_conv( + return torch.ops.torchvision.deform_conv2d( input, - offset, weight, - *stride, - *pad, - *dilation, + offset, + bias, + stride_h, stride_w, + pad_h, pad_w, + dil_h, dil_w, n_weight_grps, - n_offset_grps, - n_parallel_imgs) + n_offset_grps) -class DeformConv(nn.Module): +class DeformConv2d(nn.Module): """ - See deform_conv + See deform_conv2d """ - def __init__(self, stride=1, pad=0, dilation=1, n_parallel_imgs=64): - super(DeformConv, self).__init__() + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, + dilation=1, groups=1, offset_groups=1, bias=True): + super(DeformConv2d, self).__init__() + + if in_channels % groups != 0: + raise ValueError('in_channels must be divisible by groups') + if in_channels % offset_groups != 0: + raise ValueError('in_channels must be divisible by offset_groups') + if out_channels % groups != 0: + raise ValueError('out_channels must be divisible by groups') + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) self.stride = _pair(stride) - self.pad = _pair(pad) + self.padding = _pair(padding) self.dilation = _pair(dilation) - self.n_parallel_imgs = n_parallel_imgs + self.groups = groups + self.offset_groups = offset_groups + + self.weight = Parameter(torch.empty(out_channels, in_channels // groups, kernel_size[0], kernel_size[1])) + + self.offset_conv = nn.Conv2d( + self.in_channels, + offset_groups * 2 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation) + + if bias: + self.bias = Parameter(torch.empty(out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + init.zeros_(self.offset_conv.weight) + init.zeros_(self.offset_conv.bias) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def forward(self, input): + offset = self.offset_conv.to(device=input.device, dtype=input.dtype)(input) + weight = self.weight.to(device=input.device, dtype=input.dtype) + bias = self.bias.to(device=input.device, dtype=input.dtype) if self.bias is not None else self.bias - def forward(self, input, offset, weight): - return deform_conv(input, offset, weight, stride=self.stride, pad=self.pad, - dilation=self.dilation, n_parallel_imgs=self.n_parallel_imgs) + return deform_conv2d(input, weight, offset, bias, stride=self.stride, + padding=self.padding, dilation=self.dilation) def __repr__(self): - tmpstr = self.__class__.__name__ + '(' - tmpstr += 'output_size=' + str(self.output_size) - tmpstr += ', spatial_scale=' + str(self.spatial_scale) - tmpstr += ')' - return tmpstr + s = self.__class__.__name__ + '(' + s += '{in_channels}' + s += ', {out_channels}' + s += ', kernel_size={kernel_size}' + s += ', stride={stride}' + s += ', padding={padding}' if self.padding != (0, 0) else '' + s += ', dilation={dilation}' if self.dilation != (1, 1) else '' + s += ', groups={groups}' if self.groups != 1 else '' + s += ', offset_groups={offset_groups}' if self.offset_groups != 1 else '' + s += ', bias=False' if self.bias is None else '' + s += ')' + return s.format(**self.__dict__) From a05f3f507220a5d22f80581e5f7d6798a1a3634b Mon Sep 17 00:00:00 2001 From: Pedro Freire Date: Wed, 20 Nov 2019 10:44:37 +0000 Subject: [PATCH 3/9] Apply clang-format in DeformConv files. --- torchvision/csrc/DeformConv.h | 159 +++-- torchvision/csrc/cpu/DeformConv_cpu.cpp | 754 +++++++++++++++------ torchvision/csrc/cpu/vision_cpu.h | 3 +- torchvision/csrc/cuda/DeformConv_cuda.cu | 792 ++++++++++++++++------- torchvision/csrc/cuda/vision_cuda.h | 3 +- torchvision/csrc/vision.cpp | 2 +- 6 files changed, 1219 insertions(+), 494 deletions(-) diff --git a/torchvision/csrc/DeformConv.h b/torchvision/csrc/DeformConv.h index b057f611380..7ce41824cab 100644 --- a/torchvision/csrc/DeformConv.h +++ b/torchvision/csrc/DeformConv.h @@ -7,32 +7,49 @@ #endif at::Tensor DeformConv2d_forward( - const Tensor& input, - const Tensor& weight, - const Tensor& offset, - const Tensor& bias, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& bias, const std::pair& stride, const std::pair& padding, const std::pair& dilation, - const int groups, const int offset_groups) { + const int groups, + const int offset_groups) { if (input.type().is_cuda()) { #ifdef WITH_CUDA - return DeformConv2d_forward_cuda(input.contiguous(), weight.contiguous(), offset.contiguous(), - bias.contiguous(), stride, padding, dilation, groups, offset_groups); + return DeformConv2d_forward_cuda( + input.contiguous(), + weight.contiguous(), + offset.contiguous(), + bias.contiguous(), + stride, + padding, + dilation, + groups, + offset_groups); #else AT_ERROR("Not compiled with GPU support"); #endif } - return DeformConv2d_forward_cpu(input.contiguous(), weight.contiguous(), offset.contiguous(), - bias.contiguous(), stride, padding, dilation, groups, offset_groups); + return DeformConv2d_forward_cpu( + input.contiguous(), + weight.contiguous(), + offset.contiguous(), + bias.contiguous(), + stride, + padding, + dilation, + groups, + offset_groups); } std::tuple DeformConv2d_backward( const at::Tensor& grad, - const Tensor& input, - const Tensor& weight, - const Tensor& offset, - const Tensor& bias, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& bias, const std::pair& stride, const std::pair& padding, const std::pair& dilation, @@ -40,14 +57,32 @@ std::tuple DeformConv2d_backward const int offset_groups) { if (grad.type().is_cuda()) { #ifdef WITH_CUDA - return DeformConv2d_backward_cuda(grad.contiguous(), input.contiguous(), weight.contiguous(), offset.contiguous(), - bias.contiguous(), stride, padding, dilation, groups, offset_groups); + return DeformConv2d_backward_cuda( + grad.contiguous(), + input.contiguous(), + weight.contiguous(), + offset.contiguous(), + bias.contiguous(), + stride, + padding, + dilation, + groups, + offset_groups); #else AT_ERROR("Not compiled with GPU support"); #endif } - return DeformConv2d_backward_cpu(grad.contiguous(), input.contiguous(), weight.contiguous(), offset.contiguous(), - bias.contiguous(), stride, padding, dilation, groups, offset_groups); + return DeformConv2d_backward_cpu( + grad.contiguous(), + input.contiguous(), + weight.contiguous(), + offset.contiguous(), + bias.contiguous(), + stride, + padding, + dilation, + groups, + offset_groups); } using namespace at; @@ -56,7 +91,8 @@ using torch::autograd::AutogradContext; using torch::autograd::Variable; using torch::autograd::variable_list; -class DeformConv2dFunction : public torch::autograd::Function { +class DeformConv2dFunction + : public torch::autograd::Function { public: static variable_list forward( AutogradContext* ctx, @@ -64,17 +100,24 @@ class DeformConv2dFunction : public torch::autograd::Functionsave_for_backward({input, weight, offset, bias}); ctx->saved_data["stride_h"] = stride_h; @@ -86,7 +129,9 @@ class DeformConv2dFunction : public torch::autograd::Functionsaved_data["groups"] = groups; ctx->saved_data["offset_groups"] = offset_groups; - return {output,}; + return { + output, + }; } static variable_list backward( @@ -107,34 +152,64 @@ class DeformConv2dFunction : public torch::autograd::Functionsaved_data["groups"].toInt(); auto offset_groups = ctx->saved_data["offset_groups"].toInt(); - auto grads = DeformConv2d_backward(grad_output[0], - input, weight, offset, bias, + auto grads = DeformConv2d_backward( + grad_output[0], + input, + weight, + offset, + bias, {stride_h, stride_w}, {pad_h, pad_w}, {dilation_h, dilation_w}, - groups, offset_groups); + groups, + offset_groups); auto grad_input = std::get<0>(grads); auto grad_weight = std::get<1>(grads); auto grad_offset = std::get<2>(grads); auto grad_bias = std::get<3>(grads); - return {grad_input, grad_weight, grad_offset, - grad_bias, Variable(), Variable(), - Variable(), Variable(), Variable(), - Variable(), Variable(), Variable(),}; + return { + grad_input, + grad_weight, + grad_offset, + grad_bias, + Variable(), + Variable(), + Variable(), + Variable(), + Variable(), + Variable(), + Variable(), + Variable(), + }; } }; -Tensor deform_conv2d( - const Tensor& input, - const Tensor& weight, - const Tensor& offset, - const Tensor& bias, - int64_t stride_h, int64_t stride_w, - int64_t pad_h, int64_t pad_w, - int64_t dilation_h, int64_t dilation_w, - int64_t groups, int64_t offset_groups) { - auto result = DeformConv2dFunction::apply(input, weight, offset, bias, stride_h, stride_w, pad_h, pad_w, - dilation_h, dilation_w, groups, offset_groups); +at::Tensor deform_conv2d( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t offset_groups) { + auto result = DeformConv2dFunction::apply( + input, + weight, + offset, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups); return result[0]; } diff --git a/torchvision/csrc/cpu/DeformConv_cpu.cpp b/torchvision/csrc/cpu/DeformConv_cpu.cpp index ecbe66f6e1d..94590b341f6 100644 --- a/torchvision/csrc/cpu/DeformConv_cpu.cpp +++ b/torchvision/csrc/cpu/DeformConv_cpu.cpp @@ -1,5 +1,6 @@ /*! - ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + ******************* BEGIN Caffe Copyright Notice and Disclaimer + ***************** * * COPYRIGHT * @@ -23,22 +24,22 @@ * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR - * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE + *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * * CONTRIBUTION AGREEMENT * @@ -46,7 +47,8 @@ * or otherwise, the contributor releases their content to the * license and copyright terms herein. * - ***************** END Caffe Copyright Notice and Disclaimer ******************** + ***************** END Caffe Copyright Notice and Disclaimer + ********************* * * Copyright (c) 2018 Microsoft * Licensed under The MIT License [see LICENSE for details] @@ -58,10 +60,11 @@ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng */ -// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu - -// modified from https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp +// modified from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu +// modified from +// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp #include #include @@ -71,11 +74,15 @@ using namespace at; - const int kMaxParallelImgs = 32; template -static scalar_t bilinear_interpolate(const scalar_t *in, const int height, const int width, scalar_t h, scalar_t w) { +static scalar_t bilinear_interpolate( + const scalar_t* in, + const int height, + const int width, + scalar_t h, + scalar_t w) { if (h <= -1 || height <= h || w <= -1 || width <= w) { return 0; } @@ -109,14 +116,27 @@ static scalar_t bilinear_interpolate(const scalar_t *in, const int height, const } template -static void deformable_im2col_kernel(const int n, const scalar_t* input, const scalar_t* offset, - const int height, const int width, const int weight_h, const int weight_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, - const int dil_h, const int dil_w, - const int batch_sz, const int n_in_channels, const int n_offset_grps, - const int out_h, const int out_w, - scalar_t* columns) { - for(int index = 0; index != n; ++index) { +static void deformable_im2col_kernel( + const int n, + const scalar_t* input, + const scalar_t* offset, + const int height, + const int width, + const int weight_h, + const int weight_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dil_h, + const int dil_w, + const int batch_sz, + const int n_in_channels, + const int n_offset_grps, + const int out_h, + const int out_w, + scalar_t* columns) { + for (int index = 0; index != n; ++index) { const int out_x = index % out_w; const int out_y = (index / out_w) % out_h; const int out_b = (index / (out_w * out_h)) % batch_sz; @@ -126,21 +146,24 @@ static void deformable_im2col_kernel(const int n, const scalar_t* input, const s int c_per_offset_grp = n_in_channels / n_offset_grps; const int grp_idx = in_c / c_per_offset_grp; - auto columns_ptr = columns + (out_c * (batch_sz * out_h * out_w) - + out_b * (out_h * out_w) - + out_y * out_w - + out_x); + auto columns_ptr = columns + + (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + + out_y * out_w + out_x); - auto input_ptr = input + (out_b * (n_in_channels * height * width) - + in_c * (height * width)); + auto input_ptr = input + + (out_b * (n_in_channels * height * width) + in_c * (height * width)); - auto offset_ptr = offset + (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w; + auto offset_ptr = offset + + (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * + out_w; for (int i = 0; i < weight_h; ++i) { for (int j = 0; j < weight_w; ++j) { const int offset_idx = 2 * (i * weight_w + j); - const scalar_t offset_h = offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; - const scalar_t offset_w = offset_ptr[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t offset_h = + offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t offset_w = offset_ptr + [(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; const scalar_t y = (out_y * stride_h - pad_h) + i * dil_h + offset_h; const scalar_t x = (out_x * stride_w - pad_w) + j * dil_w + offset_w; *columns_ptr = bilinear_interpolate(input_ptr, height, width, y, x); @@ -151,14 +174,24 @@ static void deformable_im2col_kernel(const int n, const scalar_t* input, const s } static void deformable_im2col( - const at::Tensor input, const at::Tensor data_offset, int n_in_channels, - int height, int width, - int weight_h, int weight_w, - int pad_h, int pad_w, - int stride_h, int stride_w, - int dil_h, int dil_w, - int out_h, int out_w, - int parallel_imgs, int deformable_group, at::Tensor data_col) { + const at::Tensor input, + const at::Tensor data_offset, + int n_in_channels, + int height, + int width, + int weight_h, + int weight_w, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dil_h, + int dil_w, + int out_h, + int out_w, + int parallel_imgs, + int deformable_group, + at::Tensor data_col) { int num_kernels = n_in_channels * out_h * out_w * parallel_imgs; AT_DISPATCH_FLOATING_TYPES_AND_HALF( @@ -167,16 +200,28 @@ static void deformable_im2col( num_kernels, input.data_ptr(), data_offset.data_ptr(), - height, width, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dil_h, dil_w, - parallel_imgs, n_in_channels, deformable_group, - out_h, out_w, + height, + width, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dil_h, + dil_w, + parallel_imgs, + n_in_channels, + deformable_group, + out_h, + out_w, data_col.data_ptr()); })); } static int get_greatest_divisor_below_bound(int n, int bound) { - for(int k = bound; k > 1; --k) { - if(n % k == 0) { + for (int k = bound; k > 1; --k) { + if (n % k == 0) { return k; } } @@ -191,7 +236,8 @@ at::Tensor DeformConv2d_forward_cpu( std::pair stride, std::pair pad, std::pair dilation, - int n_weight_grps, int n_offset_grps) { + int n_weight_grps, + int n_offset_grps) { at::Tensor input = input_param; at::Tensor offset = offset_param; at::Tensor weight = weight_param; @@ -209,7 +255,8 @@ at::Tensor DeformConv2d_forward_cpu( int in_h = input.size(2); int in_w = input.size(3); - int n_parallel_imgs = get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); + int n_parallel_imgs = + get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); // Unpack shapes and args int out_channels = weight.size(0); @@ -227,12 +274,21 @@ at::Tensor DeformConv2d_forward_cpu( int ker_h = dil_h * (weight_h - 1) + 1; int ker_w = dil_w * (weight_w - 1) + 1; - int out_h = ((in_h + 2*pad_h - ker_h) / stride_h) + 1; - int out_w = ((in_w + 2*pad_w - ker_w) / stride_w) + 1; - - - TORCH_CHECK(weight_h > 0 && weight_w > 0, "weight_h: ", weight_h, " weight_w: ", weight_w); - TORCH_CHECK(stride_h > 0 && stride_w > 0, "stride_h: ", stride_h, " stride_w: ", stride_w); + int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; + int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; + + TORCH_CHECK( + weight_h > 0 && weight_w > 0, + "weight_h: ", + weight_h, + " weight_w: ", + weight_w); + TORCH_CHECK( + stride_h > 0 && stride_w > 0, + "stride_h: ", + stride_h, + " stride_w: ", + stride_w); TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); TORCH_CHECK(dil_h > 0 && dil_w > 0, "dil_h: ", dil_h, " dil_w: ", dil_w); @@ -240,43 +296,107 @@ at::Tensor DeformConv2d_forward_cpu( TORCH_CHECK(weight.size(0) % n_weight_grps == 0); TORCH_CHECK(input.size(1) % n_offset_grps == 0); - TORCH_CHECK((offset.size(0) == input.size(0)), "invalid batch size of offset"); - TORCH_CHECK((offset.size(1) == n_offset_grps * 2 * weight_h * weight_w), - "got: ", offset.size(1), " expected: ", n_offset_grps * 2 * weight_h * weight_w); - TORCH_CHECK((offset.size(2) == out_h && offset.size(3) == out_w), - "offset output dims: (", offset.size(2), ", ", offset.size(3), ") - ", - "computed output dims: (", out_h, ", ", out_w, ")"); - TORCH_CHECK(out_h > 0 && out_w > 0, "Calculated output size too small - out_h: ", out_h, " out_w: ", out_w); - + TORCH_CHECK( + (offset.size(0) == input.size(0)), "invalid batch size of offset"); + TORCH_CHECK( + (offset.size(1) == n_offset_grps * 2 * weight_h * weight_w), + "got: ", + offset.size(1), + " expected: ", + n_offset_grps * 2 * weight_h * weight_w); + TORCH_CHECK( + (offset.size(2) == out_h && offset.size(3) == out_w), + "offset output dims: (", + offset.size(2), + ", ", + offset.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); + TORCH_CHECK( + out_h > 0 && out_w > 0, + "Calculated output size too small - out_h: ", + out_h, + " out_w: ", + out_w); auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options()); // Separate batches into blocks - out = out.view({batch_sz / n_parallel_imgs, n_parallel_imgs, out_channels, out_h, out_w}); - input = input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - offset = offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); - at::Tensor out_buf = at::zeros({batch_sz / n_parallel_imgs, out_channels, n_parallel_imgs * out_h, out_w}, out.options()); + out = out.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + out_channels, + out_h, + out_w}); + input = input.view( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + offset = offset.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + at::Tensor out_buf = at::zeros( + {batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs * out_h, + out_w}, + out.options()); // Separate channels into convolution groups - out_buf = out_buf.view({out_buf.size(0), n_weight_grps, out_buf.size(1) / n_weight_grps, out_buf.size(2), out_buf.size(3)}); - weight = weight.view({n_weight_grps, weight.size(0) / n_weight_grps, weight.size(1), weight.size(2), weight.size(3)}); + out_buf = out_buf.view({out_buf.size(0), + n_weight_grps, + out_buf.size(1) / n_weight_grps, + out_buf.size(2), + out_buf.size(3)}); + weight = weight.view({n_weight_grps, + weight.size(0) / n_weight_grps, + weight.size(1), + weight.size(2), + weight.size(3)}); // Sample points and perform convolution - auto columns = at::zeros({n_in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, input.options()); + auto columns = at::zeros( + {n_in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, + input.options()); for (int b = 0; b < batch_sz / n_parallel_imgs; b++) { - deformable_im2col(input[b], offset[b], n_in_channels, in_h, - in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dil_h, - dil_w, out_h, out_w, n_parallel_imgs, n_offset_grps, columns); - - columns = columns.view({n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + deformable_im2col( + input[b], + offset[b], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dil_h, + dil_w, + out_h, + out_w, + n_parallel_imgs, + n_offset_grps, + columns); + + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); for (int g = 0; g < n_weight_grps; g++) { - out_buf[b][g] = out_buf[b][g].flatten(1) - .addmm_(weight[g].flatten(1), columns[g]) - .view_as(out_buf[b][g]); + out_buf[b][g] = out_buf[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(out_buf[b][g]); } } - out_buf = out_buf.view({batch_sz / n_parallel_imgs, out_channels, n_parallel_imgs, out_h, out_w}); + out_buf = out_buf.view({batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs, + out_h, + out_w}); out_buf.transpose_(1, 2); out.copy_(out_buf); out = out.view({batch_sz, out_channels, out_h, out_w}); @@ -284,19 +404,28 @@ at::Tensor DeformConv2d_forward_cpu( return out + bias.view({1, out_channels, 1, 1}); } - template static void deformable_col2im_kernel( - const int n, const scalar_t *col, const scalar_t *offset, - const int channels, const int height, const int width, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int batch_sz, const int n_offset_grps, - const int out_h, const int out_w, - scalar_t *grad_im) { - for(int index = 0; index != n; ++index) { + const int n, + const scalar_t* col, + const scalar_t* offset, + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int batch_sz, + const int n_offset_grps, + const int out_h, + const int out_w, + scalar_t* grad_im) { + for (int index = 0; index != n; ++index) { const int j = (index / (out_w * out_h * batch_sz)) % kernel_w; const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h; const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h); @@ -308,9 +437,13 @@ static void deformable_col2im_kernel( int out_y = (index / out_w) % out_h; int b = (index / (out_w * out_h)) % batch_sz; - auto offset_ptr = offset + (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h * out_w; - const int offset_h_ptr = ((2 * (i * kernel_w + j)) * out_h + out_y) * out_w + out_x; - const int offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * out_h + out_y) * out_w + out_x; + auto offset_ptr = offset + + (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h * + out_w; + const int offset_h_ptr = + ((2 * (i * kernel_w + j)) * out_h + out_y) * out_w + out_x; + const int offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * out_h + out_y) * out_w + out_x; const scalar_t offset_h = offset_ptr[offset_h_ptr]; const scalar_t offset_w = offset_ptr[offset_w_ptr]; const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; @@ -320,10 +453,8 @@ static void deformable_col2im_kernel( for (int dx = -1; dx <= 1; dx++) { int yp = int(y) + dy; int xp = int(x) + dx; - if (0 <= yp && yp < height && - 0 <= xp && xp < width && - abs(y - yp) < 1 && - abs(x - xp) < 1) { + if (0 <= yp && yp < height && 0 <= xp && xp < width && + abs(y - yp) < 1 && abs(x - xp) < 1) { int grad_pos = ((b * channels + c) * height + yp) * width + xp; scalar_t weight = (1 - abs(y - yp)) * (1 - abs(x - xp)); grad_im[grad_pos] += weight * col[index]; @@ -334,16 +465,28 @@ static void deformable_col2im_kernel( } static void compute_grad_input( - const at::Tensor columns, const at::Tensor offset, const int channels, - const int height, const int width, const int weight_h, - const int weight_w, const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int parallel_imgs, const int n_offset_grps, + const at::Tensor columns, + const at::Tensor offset, + const int channels, + const int height, + const int width, + const int weight_h, + const int weight_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int parallel_imgs, + const int n_offset_grps, at::Tensor grad_im) { - int out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; - int out_w = (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; - int num_kernels = channels * weight_h * weight_w * out_h * out_w * parallel_imgs; + int out_h = + (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + int out_w = + (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + int num_kernels = + channels * weight_h * weight_w * out_h * out_w * parallel_imgs; AT_DISPATCH_FLOATING_TYPES_AND_HALF( columns.scalar_type(), "deformable_col2im", ([&] { @@ -351,16 +494,33 @@ static void compute_grad_input( num_kernels, columns.data_ptr(), offset.data_ptr(), - channels, height, width, - weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, - parallel_imgs, n_offset_grps, out_h, out_w, + channels, + height, + width, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + parallel_imgs, + n_offset_grps, + out_h, + out_w, grad_im.data_ptr()); })); } - template -static scalar_t get_coordinate_weight(const scalar_t *im_data, const int height, const int width, scalar_t y, scalar_t x, bool is_y_direction) { +static scalar_t get_coordinate_weight( + const scalar_t* im_data, + const int height, + const int width, + scalar_t y, + scalar_t x, + bool is_y_direction) { int y_l = floor(y); int x_l = floor(x); int y_h = y_l + 1; @@ -386,18 +546,30 @@ static scalar_t get_coordinate_weight(const scalar_t *im_data, const int height, } } - template -static void deformable_col2im_coord_kernel(const int n, const scalar_t *col, - const scalar_t *im, const scalar_t *offset, - const int channels, const int height, const int width, - const int weight_h, const int weight_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int batch_sz, const int offset_channels, const int n_offset_grps, - const int out_h, const int out_w, scalar_t *grad_offset) { - for(int index = 0; index != n; ++index) { +static void deformable_col2im_coord_kernel( + const int n, + const scalar_t* col, + const scalar_t* im, + const scalar_t* offset, + const int channels, + const int height, + const int width, + const int weight_h, + const int weight_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int batch_sz, + const int offset_channels, + const int n_offset_grps, + const int out_h, + const int out_w, + scalar_t* grad_offset) { + for (int index = 0; index != n; ++index) { scalar_t val = 0; int w = index % out_w; int h = (index / out_w) % out_h; @@ -409,9 +581,14 @@ static void deformable_col2im_coord_kernel(const int n, const scalar_t *col, int c_per_offset_grp = channels / n_offset_grps; - auto col_ptr = col + offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * out_w * out_h; - auto im_ptr = im + (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; - auto offset_ptr = offset + (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h * out_w; + auto col_ptr = col + + offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * out_w * + out_h; + auto im_ptr = im + + (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; + auto offset_ptr = offset + + (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h * + out_w; const int offset_c = c - offset_grp * 2 * weight_h * weight_w; const int is_y_direction = offset_c % 2 == 0; @@ -425,15 +602,18 @@ static void deformable_col2im_coord_kernel(const int n, const scalar_t *col, int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w; int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; - const int offset_h_idx = (((2 * (i * weight_w + j)) * out_h + out_y) * out_w + out_x); - const int offset_w_idx = (((2 * (i * weight_w + j) + 1) * out_h + out_y) * out_w + out_x); + const int offset_h_idx = + (((2 * (i * weight_w + j)) * out_h + out_y) * out_w + out_x); + const int offset_w_idx = + (((2 * (i * weight_w + j) + 1) * out_h + out_y) * out_w + out_x); const scalar_t offset_h = offset_ptr[offset_h_idx]; const scalar_t offset_w = offset_ptr[offset_w_idx]; scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; - const scalar_t weight = get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); + const scalar_t weight = + get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); val += weight * col_ptr[col_pos]; im_ptr += height * width; } @@ -443,14 +623,29 @@ static void deformable_col2im_coord_kernel(const int n, const scalar_t *col, } static void compute_grad_offset( - const at::Tensor columns, const at::Tensor input, const at::Tensor offset, - const int channels, const int height, const int width, const int weight_h, - const int weight_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, const int dilation_h, const int dilation_w, - const int parallel_imgs, const int n_offset_grps, at::Tensor grad_offset) { - int out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; - int out_w = (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; - int num_kernels = out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs; + const at::Tensor columns, + const at::Tensor input, + const at::Tensor offset, + const int channels, + const int height, + const int width, + const int weight_h, + const int weight_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int parallel_imgs, + const int n_offset_grps, + at::Tensor grad_offset) { + int out_h = + (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + int out_w = + (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + int num_kernels = + out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs; AT_DISPATCH_FLOATING_TYPES_AND_HALF( columns.scalar_type(), "deformable_col2im_coord", ([&] { @@ -459,24 +654,37 @@ static void compute_grad_offset( columns.data_ptr(), input.data_ptr(), offset.data_ptr(), - channels, height, width, weight_h, - weight_w, pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, - parallel_imgs, 2 * weight_h * weight_w * n_offset_grps, n_offset_grps, - out_h, out_w, + channels, + height, + width, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + parallel_imgs, + 2 * weight_h * weight_w * n_offset_grps, + n_offset_grps, + out_h, + out_w, grad_offset.data_ptr()); })); } - static std::tuple deform_conv2d_backward_input_cpu( - at::Tensor input, at::Tensor weight, at::Tensor offset, + at::Tensor input, + at::Tensor weight, + at::Tensor offset, at::Tensor grad_out, std::pair stride, std::pair pad, std::pair dilation, - int n_weight_grps, int n_offset_grps, int n_parallel_imgs) { - + int n_weight_grps, + int n_offset_grps, + int n_parallel_imgs) { int batch_sz = input.size(0); int n_in_channels = input.size(1); int in_h = input.size(2); @@ -502,63 +710,122 @@ static std::tuple deform_conv2d_backward_input_cpu( auto grad_input = at::zeros_like(input); auto grad_offset = at::zeros_like(offset); - auto columns = at::zeros({n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, input.options()); + auto columns = at::zeros( + {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, + input.options()); // Separate into blocks - grad_input = grad_input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - input = input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - grad_offset = grad_offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); - offset = offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); - - grad_out = grad_out.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_out_channels, out_h, out_w}); + grad_input = grad_input.view( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + input = input.view( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + grad_offset = grad_offset.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + offset = offset.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + grad_out = grad_out.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_out_channels, + out_h, + out_w}); grad_out.transpose_(1, 2); - grad_out = grad_out.view( - {grad_out.size(0), n_weight_grps, grad_out.size(1) / n_weight_grps, - grad_out.size(2), grad_out.size(3), grad_out.size(4)}); + grad_out = grad_out.view({grad_out.size(0), + n_weight_grps, + grad_out.size(1) / n_weight_grps, + grad_out.size(2), + grad_out.size(3), + grad_out.size(4)}); for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { // Separate into weight groups - columns = columns.view({n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); - weight = weight.view({n_weight_grps, weight.size(0) / n_weight_grps, weight.size(1), weight.size(2), weight.size(3)}); + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + weight = weight.view({n_weight_grps, + weight.size(0) / n_weight_grps, + weight.size(1), + weight.size(2), + weight.size(3)}); for (int g = 0; g < n_weight_grps; g++) { - columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); + columns[g] = columns[g].addmm_( + weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); } - columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - - compute_grad_offset(columns, input[elt], offset[elt], n_in_channels, - in_h, in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, - dil_h, dil_w, n_parallel_imgs, n_offset_grps, - grad_offset[elt]); - - compute_grad_input(columns, offset[elt], n_in_channels, in_h, - in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dil_h, - dil_w, n_parallel_imgs, n_offset_grps, grad_input[elt]); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + + compute_grad_offset( + columns, + input[elt], + offset[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dil_h, + dil_w, + n_parallel_imgs, + n_offset_grps, + grad_offset[elt]); + + compute_grad_input( + columns, + offset[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dil_h, + dil_w, + n_parallel_imgs, + n_offset_grps, + grad_input[elt]); } - grad_out = grad_out.view( - {grad_out.size(0), grad_out.size(1) * grad_out.size(2), - grad_out.size(3), grad_out.size(4), grad_out.size(5)}); + grad_out = grad_out.view({grad_out.size(0), + grad_out.size(1) * grad_out.size(2), + grad_out.size(3), + grad_out.size(4), + grad_out.size(5)}); grad_out.transpose_(1, 2); grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w}); grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); input = input.view({batch_sz, n_in_channels, in_h, in_w}); - grad_offset = grad_offset.view({batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); - offset = offset.view({batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + grad_offset = grad_offset.view( + {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + offset = offset.view( + {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); return {grad_input, grad_offset}; } - - static at::Tensor deform_conv2d_backward_parameters_cpu( - at::Tensor input, at::Tensor weight, at::Tensor offset, + at::Tensor input, + at::Tensor weight, + at::Tensor offset, at::Tensor grad_out, std::pair stride, std::pair pad, std::pair dilation, - int n_weight_grps, int n_offset_grps, int n_parallel_imgs) { - + int n_weight_grps, + int n_offset_grps, + int n_parallel_imgs) { int batch_sz = input.size(0); int n_in_channels = input.size(1); int in_h = input.size(2); @@ -582,50 +849,95 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( long out_h = grad_out.size(2); long out_w = grad_out.size(3); - auto grad_weight = at::zeros_like(weight);; - auto columns = at::zeros({n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, input.options()); - - grad_out = grad_out.view({batch_sz / n_parallel_imgs, n_parallel_imgs, - n_out_channels, out_h, out_w}); + auto grad_weight = at::zeros_like(weight); + ; + auto columns = at::zeros( + {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, + input.options()); + + grad_out = grad_out.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_out_channels, + out_h, + out_w}); grad_out.transpose_(1, 2); at::Tensor grad_out_buf = at::zeros_like(grad_out); grad_out_buf.copy_(grad_out); - grad_out_buf = grad_out_buf.view({batch_sz / n_parallel_imgs, n_out_channels, n_parallel_imgs * out_h, out_w}); - grad_out_buf = grad_out_buf.view({grad_out_buf.size(0), n_weight_grps, grad_out_buf.size(1) / n_weight_grps, grad_out_buf.size(2), grad_out_buf.size(3)}); + grad_out_buf = grad_out_buf.view({batch_sz / n_parallel_imgs, + n_out_channels, + n_parallel_imgs * out_h, + out_w}); + grad_out_buf = grad_out_buf.view({grad_out_buf.size(0), + n_weight_grps, + grad_out_buf.size(1) / n_weight_grps, + grad_out_buf.size(2), + grad_out_buf.size(3)}); grad_out.transpose_(1, 2); grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w}); - input = input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - offset = offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); - - grad_weight = grad_weight.view({n_weight_grps, grad_weight.size(0) / n_weight_grps, grad_weight.size(1), grad_weight.size(2), grad_weight.size(3)}); + input = input.view( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + offset = offset.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + grad_weight = grad_weight.view({n_weight_grps, + grad_weight.size(0) / n_weight_grps, + grad_weight.size(1), + grad_weight.size(2), + grad_weight.size(3)}); for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { - deformable_im2col(input[elt], offset[elt], n_in_channels, in_h, - in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dil_h, - dil_w, out_h, out_w, n_parallel_imgs, n_offset_grps, columns); - - columns = columns.view({n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + deformable_im2col( + input[elt], + offset[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dil_h, + dil_w, + out_h, + out_w, + n_parallel_imgs, + n_offset_grps, + columns); + + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); for (int g = 0; g < n_weight_grps; g++) { - grad_weight[g] = grad_weight[g] - .flatten(1) - .addmm_(grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) - .view_as(grad_weight[g]); + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_( + grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) + .view_as(grad_weight[g]); } - columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); } input = input.view({batch_sz, n_in_channels, in_h, in_w}); - offset = offset.view({batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + offset = offset.view( + {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), - grad_weight.size(2), grad_weight.size(3), grad_weight.size(4)}); + grad_weight.size(2), + grad_weight.size(3), + grad_weight.size(4)}); return grad_weight; } - -std::tuple DeformConv2d_backward_cpu( +std::tuple +DeformConv2d_backward_cpu( const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& weight, @@ -634,22 +946,38 @@ std::tuple DeformConv2d_backward std::pair stride, std::pair pad, std::pair dilation, - int n_weight_grps, int n_offset_grps) { + int n_weight_grps, + int n_offset_grps) { const int batch_sz = input.size(0); - const int n_parallel_imgs = get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); + const int n_parallel_imgs = + get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); auto grad_input_and_offset = deform_conv2d_backward_input_cpu( - input, weight, offset, grad_out, - stride, pad, dilation, - n_weight_grps, n_offset_grps, n_parallel_imgs); + input, + weight, + offset, + grad_out, + stride, + pad, + dilation, + n_weight_grps, + n_offset_grps, + n_parallel_imgs); auto grad_input = std::get<0>(grad_input_and_offset); auto grad_offset = std::get<1>(grad_input_and_offset); auto grad_weight = deform_conv2d_backward_parameters_cpu( - input, weight, offset, grad_out, - stride, pad, dilation, - n_weight_grps, n_offset_grps, n_parallel_imgs); + input, + weight, + offset, + grad_out, + stride, + pad, + dilation, + n_weight_grps, + n_offset_grps, + n_parallel_imgs); auto grad_bias = at::ones_like(bias) * grad_out.sum({0, 2, 3}); diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index 0cb03c7c782..b133c9ff1a3 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -96,7 +96,8 @@ at::Tensor DeformConv2d_forward_cpu( int groups, int deformable_groups); -std::tuple DeformConv2d_backward_cpu( +std::tuple +DeformConv2d_backward_cpu( const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& weight, diff --git a/torchvision/csrc/cuda/DeformConv_cuda.cu b/torchvision/csrc/cuda/DeformConv_cuda.cu index 225daea8b8c..008385d86a9 100644 --- a/torchvision/csrc/cuda/DeformConv_cuda.cu +++ b/torchvision/csrc/cuda/DeformConv_cuda.cu @@ -1,5 +1,6 @@ /*! - ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + ******************* BEGIN Caffe Copyright Notice and Disclaimer + ***************** * * COPYRIGHT * @@ -23,22 +24,22 @@ * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR - * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE + *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * * CONTRIBUTION AGREEMENT * @@ -46,7 +47,8 @@ * or otherwise, the contributor releases their content to the * license and copyright terms herein. * - ***************** END Caffe Copyright Notice and Disclaimer ******************** + ***************** END Caffe Copyright Notice and Disclaimer + ********************* * * Copyright (c) 2018 Microsoft * Licensed under The MIT License [see LICENSE for details] @@ -58,10 +60,11 @@ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng */ -// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu - -// modified from https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp +// modified from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu +// modified from +// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp #include #include @@ -73,7 +76,6 @@ #include - using namespace at; const int CUDA_NUM_THREADS = 1024; @@ -81,14 +83,17 @@ const int kMaxGridNum = 65535; const int kMaxParallelImgs = 32; -inline int GET_BLOCKS(const int N) -{ +inline int GET_BLOCKS(const int N) { return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); } template -__device__ scalar_t bilinear_interpolate(const scalar_t *in, const int height, const int width, scalar_t h, scalar_t w) -{ +__device__ scalar_t bilinear_interpolate( + const scalar_t* in, + const int height, + const int width, + scalar_t h, + scalar_t w) { if (h <= -1 || height <= h || w <= -1 || width <= w) { return 0; } @@ -122,16 +127,27 @@ __device__ scalar_t bilinear_interpolate(const scalar_t *in, const int height, c } template -__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t* input_ptr, const scalar_t* offset_ptr, - const int height, const int width, const int weight_h, const int weight_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, - const int dil_h, const int dil_w, - const int batch_sz, const int n_in_channels, const int n_offset_grps, - const int out_h, const int out_w, - scalar_t* columns_ptr) -{ - CUDA_1D_KERNEL_LOOP(index, n) - { +__global__ void deformable_im2col_gpu_kernel( + const int n, + const scalar_t* input_ptr, + const scalar_t* offset_ptr, + const int height, + const int width, + const int weight_h, + const int weight_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dil_h, + const int dil_w, + const int batch_sz, + const int n_in_channels, + const int n_offset_grps, + const int out_h, + const int out_w, + scalar_t* columns_ptr) { + CUDA_1D_KERNEL_LOOP(index, n) { const int out_x = index % out_w; const int out_y = (index / out_w) % out_h; const int out_b = (index / (out_w * out_h)) % batch_sz; @@ -141,21 +157,23 @@ __global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t* input_ int c_per_offset_grp = n_in_channels / n_offset_grps; const int grp_idx = in_c / c_per_offset_grp; - columns_ptr += (out_c * (batch_sz * out_h * out_w) - + out_b * (out_h * out_w) - + out_y * out_w - + out_x); + columns_ptr += + (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + + out_y * out_w + out_x); - input_ptr += (out_b * (n_in_channels * height * width) - + in_c * (height * width)); + input_ptr += + (out_b * (n_in_channels * height * width) + in_c * (height * width)); - offset_ptr += (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w; + offset_ptr += (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * + out_h * out_w; for (int i = 0; i < weight_h; ++i) { for (int j = 0; j < weight_w; ++j) { const int offset_idx = 2 * (i * weight_w + j); - const scalar_t offset_h = offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; - const scalar_t offset_w = offset_ptr[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t offset_h = + offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t offset_w = offset_ptr + [(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; const scalar_t y = (out_y * stride_h - pad_h) + i * dil_h + offset_h; const scalar_t x = (out_x * stride_w - pad_w) + j * dil_w + offset_w; *columns_ptr = bilinear_interpolate(input_ptr, height, width, y, x); @@ -165,39 +183,62 @@ __global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t* input_ } } -void deformable_im2col( - const at::Tensor input, const at::Tensor data_offset, int n_in_channels, - int height, int width, - int weight_h, int weight_w, - int pad_h, int pad_w, - int stride_h, int stride_w, - int dil_h, int dil_w, - int out_h, int out_w, - int parallel_imgs, int deformable_group, at::Tensor data_col) { +static void deformable_im2col( + const at::Tensor input, + const at::Tensor data_offset, + int n_in_channels, + int height, + int width, + int weight_h, + int weight_w, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dil_h, + int dil_w, + int out_h, + int out_w, + int parallel_imgs, + int deformable_group, + at::Tensor data_col) { int num_kernels = n_in_channels * out_h * out_w * parallel_imgs; AT_DISPATCH_FLOATING_TYPES_AND_HALF( input.scalar_type(), "deformable_im2col_gpu", ([&] { - deformable_im2col_gpu_kernel<<>>( + deformable_im2col_gpu_kernel<<< + GET_BLOCKS(num_kernels), + CUDA_NUM_THREADS>>>( num_kernels, input.data_ptr(), data_offset.data_ptr(), - height, width, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dil_h, dil_w, - parallel_imgs, n_in_channels, deformable_group, - out_h, out_w, + height, + width, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dil_h, + dil_w, + parallel_imgs, + n_in_channels, + deformable_group, + out_h, + out_w, data_col.data_ptr()); })); cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { + if (err != cudaSuccess) { printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); } } -int get_greatest_divisor_below_bound(int n, int bound) { - for(int k = bound; k > 1; --k) { - if(n % k == 0) { +static int get_greatest_divisor_below_bound(int n, int bound) { + for (int k = bound; k > 1; --k) { + if (n % k == 0) { return k; } } @@ -212,7 +253,8 @@ at::Tensor DeformConv2d_forward_cuda( std::pair stride, std::pair pad, std::pair dilation, - int n_weight_grps, int n_offset_grps) { + int n_weight_grps, + int n_offset_grps) { at::Tensor input = input_param; at::Tensor weight = weight_param; at::Tensor offset = offset_param; @@ -232,7 +274,8 @@ at::Tensor DeformConv2d_forward_cuda( int in_h = input.size(2); int in_w = input.size(3); - int n_parallel_imgs = get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); + int n_parallel_imgs = + get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); int out_channels = weight.size(0); int weight_h = weight.size(2); @@ -249,12 +292,21 @@ at::Tensor DeformConv2d_forward_cuda( int ker_h = dil_h * (weight_h - 1) + 1; int ker_w = dil_w * (weight_w - 1) + 1; - int out_h = ((in_h + 2*pad_h - ker_h) / stride_h) + 1; - int out_w = ((in_w + 2*pad_w - ker_w) / stride_w) + 1; - - - TORCH_CHECK(weight_h > 0 && weight_w > 0, "weight_h: ", weight_h, " weight_w: ", weight_w); - TORCH_CHECK(stride_h > 0 && stride_w > 0, "stride_h: ", stride_h, " stride_w: ", stride_w); + int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; + int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; + + TORCH_CHECK( + weight_h > 0 && weight_w > 0, + "weight_h: ", + weight_h, + " weight_w: ", + weight_w); + TORCH_CHECK( + stride_h > 0 && stride_w > 0, + "stride_h: ", + stride_h, + " stride_w: ", + stride_w); TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); TORCH_CHECK(dil_h > 0 && dil_w > 0, "dil_h: ", dil_h, " dil_w: ", dil_w); @@ -262,43 +314,107 @@ at::Tensor DeformConv2d_forward_cuda( TORCH_CHECK(weight.size(0) % n_weight_grps == 0); TORCH_CHECK(input.size(1) % n_offset_grps == 0); - TORCH_CHECK((offset.size(0) == input.size(0)), "invalid batch size of offset"); - TORCH_CHECK((offset.size(1) == n_offset_grps * 2 * weight_h * weight_w), - "got: ", offset.size(1), " expected: ", n_offset_grps * 2 * weight_h * weight_w); - TORCH_CHECK((offset.size(2) == out_h && offset.size(3) == out_w), - "offset output dims: (", offset.size(2), ", ", offset.size(3), ") - ", - "computed output dims: (", out_h, ", ", out_w, ")"); - TORCH_CHECK(out_h > 0 && out_w > 0, "Calculated output size too small - out_h: ", out_h, " out_w: ", out_w); - + TORCH_CHECK( + (offset.size(0) == input.size(0)), "invalid batch size of offset"); + TORCH_CHECK( + (offset.size(1) == n_offset_grps * 2 * weight_h * weight_w), + "got: ", + offset.size(1), + " expected: ", + n_offset_grps * 2 * weight_h * weight_w); + TORCH_CHECK( + (offset.size(2) == out_h && offset.size(3) == out_w), + "offset output dims: (", + offset.size(2), + ", ", + offset.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); + TORCH_CHECK( + out_h > 0 && out_w > 0, + "Calculated output size too small - out_h: ", + out_h, + " out_w: ", + out_w); auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options()); // Separate batches into blocks - out = out.view({batch_sz / n_parallel_imgs, n_parallel_imgs, out_channels, out_h, out_w}); - input = input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w}); - offset = offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); - at::Tensor out_buf = at::zeros({batch_sz / n_parallel_imgs, out_channels, n_parallel_imgs * out_h, out_w}, out.options()); + out = out.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + out_channels, + out_h, + out_w}); + input = input.view( + {batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w}); + offset = offset.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + at::Tensor out_buf = at::zeros( + {batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs * out_h, + out_w}, + out.options()); // Separate channels into convolution groups - out_buf = out_buf.view({out_buf.size(0), n_weight_grps, out_buf.size(1) / n_weight_grps, out_buf.size(2), out_buf.size(3)}); - weight = weight.view({n_weight_grps, weight.size(0) / n_weight_grps, weight.size(1), weight.size(2), weight.size(3)}); + out_buf = out_buf.view({out_buf.size(0), + n_weight_grps, + out_buf.size(1) / n_weight_grps, + out_buf.size(2), + out_buf.size(3)}); + weight = weight.view({n_weight_grps, + weight.size(0) / n_weight_grps, + weight.size(1), + weight.size(2), + weight.size(3)}); // Sample points and perform convolution - auto columns = at::zeros({in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, input.options()); + auto columns = at::zeros( + {in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, + input.options()); for (int b = 0; b < batch_sz / n_parallel_imgs; b++) { - deformable_im2col(input[b], offset[b], in_channels, in_h, - in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dil_h, - dil_w, out_h, out_w, n_parallel_imgs, n_offset_grps, columns); - - columns = columns.view({n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + deformable_im2col( + input[b], + offset[b], + in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dil_h, + dil_w, + out_h, + out_w, + n_parallel_imgs, + n_offset_grps, + columns); + + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); for (int g = 0; g < n_weight_grps; g++) { - out_buf[b][g] = out_buf[b][g].flatten(1) - .addmm_(weight[g].flatten(1), columns[g]) - .view_as(out_buf[b][g]); + out_buf[b][g] = out_buf[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(out_buf[b][g]); } } - out_buf = out_buf.view({batch_sz / n_parallel_imgs, out_channels, n_parallel_imgs, out_h, out_w}); + out_buf = out_buf.view({batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs, + out_h, + out_w}); out_buf.transpose_(1, 2); out.copy_(out_buf); out = out.view({batch_sz, out_channels, out_h, out_w}); @@ -306,21 +422,28 @@ at::Tensor DeformConv2d_forward_cuda( return out + bias.view({1, out_channels, 1, 1}); } - template __global__ void deformable_col2im_gpu_kernel( - const int n, const scalar_t *col, const scalar_t *offset_ptr, - const int channels, const int height, const int width, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int batch_sz, const int n_offset_grps, - const int out_h, const int out_w, - scalar_t *grad_im) -{ - CUDA_1D_KERNEL_LOOP(index, n) - { + const int n, + const scalar_t* col, + const scalar_t* offset_ptr, + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int batch_sz, + const int n_offset_grps, + const int out_h, + const int out_w, + scalar_t* grad_im) { + CUDA_1D_KERNEL_LOOP(index, n) { const int out_x = index % out_w; const int out_y = (index / out_w) % out_h; const int b = (index / (out_w * out_h)) % batch_sz; @@ -331,9 +454,12 @@ __global__ void deformable_col2im_gpu_kernel( int c_per_offset_grp = channels / n_offset_grps; const int offset_grp = c / c_per_offset_grp; - offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h * out_w; - const int offset_h_ptr = ((2 * (i * kernel_w + j)) * out_h + out_y) * out_w + out_x; - const int offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * out_h + out_y) * out_w + out_x; + offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * + out_h * out_w; + const int offset_h_ptr = + ((2 * (i * kernel_w + j)) * out_h + out_y) * out_w + out_x; + const int offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * out_h + out_y) * out_w + out_x; const scalar_t offset_h = offset_ptr[offset_h_ptr]; const scalar_t offset_w = offset_ptr[offset_w_ptr]; const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; @@ -343,10 +469,8 @@ __global__ void deformable_col2im_gpu_kernel( for (int dx = -1; dx <= 1; dx++) { int yp = int(y) + dy; int xp = int(x) + dx; - if (0 <= yp && yp < height && - 0 <= xp && xp < width && - abs(y - yp) < 1 && - abs(x - xp) < 1) { + if (0 <= yp && yp < height && 0 <= xp && xp < width && + abs(y - yp) < 1 && abs(x - xp) < 1) { int grad_pos = ((b * channels + c) * height + yp) * width + xp; scalar_t weight = (1 - abs(y - yp)) * (1 - abs(x - xp)); atomicAdd(grad_im + grad_pos, weight * col[index]); @@ -356,40 +480,70 @@ __global__ void deformable_col2im_gpu_kernel( } } -void compute_grad_input( - const at::Tensor columns, const at::Tensor offset, const int channels, - const int height, const int width, const int weight_h, - const int weight_w, const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int parallel_imgs, const int n_offset_grps, +static void compute_grad_input( + const at::Tensor columns, + const at::Tensor offset, + const int channels, + const int height, + const int width, + const int weight_h, + const int weight_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int parallel_imgs, + const int n_offset_grps, at::Tensor grad_im) { - int out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; - int out_w = (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; - int num_kernels = channels * weight_h * weight_w * out_h * out_w * parallel_imgs; + int out_h = + (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + int out_w = + (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + int num_kernels = + channels * weight_h * weight_w * out_h * out_w * parallel_imgs; AT_DISPATCH_FLOATING_TYPES_AND_HALF( columns.scalar_type(), "deformable_col2im_gpu", ([&] { - deformable_col2im_gpu_kernel<<>>( + deformable_col2im_gpu_kernel<<< + GET_BLOCKS(num_kernels), + CUDA_NUM_THREADS>>>( num_kernels, columns.data_ptr(), offset.data_ptr(), - channels, height, width, - weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, - parallel_imgs, n_offset_grps, out_h, out_w, + channels, + height, + width, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + parallel_imgs, + n_offset_grps, + out_h, + out_w, grad_im.data_ptr()); })); cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { + if (err != cudaSuccess) { printf("error in compute_grad_input: %s\n", cudaGetErrorString(err)); } } - template -__device__ scalar_t get_coordinate_weight(const scalar_t *im_data, const int height, const int width, scalar_t y, scalar_t x, bool is_y_direction) { +__device__ scalar_t get_coordinate_weight( + const scalar_t* im_data, + const int height, + const int width, + scalar_t y, + scalar_t x, + bool is_y_direction) { int y_l = floor(y); int x_l = floor(x); int y_h = y_l + 1; @@ -415,20 +569,30 @@ __device__ scalar_t get_coordinate_weight(const scalar_t *im_data, const int hei } } - template -__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *col_ptr, - const scalar_t *im_ptr, const scalar_t *offset_ptr, - const int channels, const int height, const int width, - const int weight_h, const int weight_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int batch_sz, const int offset_channels, const int n_offset_grps, - const int out_h, const int out_w, scalar_t *grad_offset) -{ - CUDA_1D_KERNEL_LOOP(index, n) - { +__global__ void deformable_col2im_coord_gpu_kernel( + const int n, + const scalar_t* col_ptr, + const scalar_t* im_ptr, + const scalar_t* offset_ptr, + const int channels, + const int height, + const int width, + const int weight_h, + const int weight_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int batch_sz, + const int offset_channels, + const int n_offset_grps, + const int out_h, + const int out_w, + scalar_t* grad_offset) { + CUDA_1D_KERNEL_LOOP(index, n) { scalar_t val = 0; int w = index % out_w; int h = (index / out_w) % out_h; @@ -440,9 +604,12 @@ __global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t * int c_per_offset_grp = channels / n_offset_grps; - col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * out_w * out_h; - im_ptr += (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; - offset_ptr += (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h * out_w; + col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * + out_w * out_h; + im_ptr += + (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; + offset_ptr += (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * + out_h * out_w; const int offset_c = c - offset_grp * 2 * weight_h * weight_w; const int is_y_direction = offset_c % 2 == 0; @@ -456,15 +623,18 @@ __global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t * int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w; int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; - const int offset_h_ptr = (((2 * (i * weight_w + j)) * out_h + out_y) * out_w + out_x); - const int offset_w_ptr = (((2 * (i * weight_w + j) + 1) * out_h + out_y) * out_w + out_x); + const int offset_h_ptr = + (((2 * (i * weight_w + j)) * out_h + out_y) * out_w + out_x); + const int offset_w_ptr = + (((2 * (i * weight_w + j) + 1) * out_h + out_y) * out_w + out_x); const scalar_t offset_h = offset_ptr[offset_h_ptr]; const scalar_t offset_w = offset_ptr[offset_w_ptr]; scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; - const scalar_t weight = get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); + const scalar_t weight = + get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); val += weight * col_ptr[col_pos]; im_ptr += height * width; } @@ -473,47 +643,76 @@ __global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t * } } - -void compute_grad_offset( - const at::Tensor columns, const at::Tensor input, const at::Tensor offset, - const int channels, const int height, const int width, const int weight_h, - const int weight_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, const int dilation_h, const int dilation_w, - const int parallel_imgs, const int n_offset_grps, at::Tensor grad_offset) { - int out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; - int out_w = (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; - int num_kernels = out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs; +static void compute_grad_offset( + const at::Tensor columns, + const at::Tensor input, + const at::Tensor offset, + const int channels, + const int height, + const int width, + const int weight_h, + const int weight_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int parallel_imgs, + const int n_offset_grps, + at::Tensor grad_offset) { + int out_h = + (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + int out_w = + (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + int num_kernels = + out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs; AT_DISPATCH_FLOATING_TYPES_AND_HALF( columns.scalar_type(), "deformable_col2im_coord_gpu", ([&] { - deformable_col2im_coord_gpu_kernel<<>>( + deformable_col2im_coord_gpu_kernel<<< + GET_BLOCKS(num_kernels), + CUDA_NUM_THREADS>>>( num_kernels, columns.data_ptr(), input.data_ptr(), offset.data_ptr(), - channels, height, width, weight_h, - weight_w, pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, - parallel_imgs, 2 * weight_h * weight_w * n_offset_grps, n_offset_grps, - out_h, out_w, + channels, + height, + width, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + parallel_imgs, + 2 * weight_h * weight_w * n_offset_grps, + n_offset_grps, + out_h, + out_w, grad_offset.data_ptr()); })); cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { + if (err != cudaSuccess) { printf("error in compute_grad_offset: %s\n", cudaGetErrorString(err)); } } - -std::tuple deform_conv_backward_input_cuda( - at::Tensor input, at::Tensor weight, at::Tensor offset, +static std::tuple deform_conv_backward_input_cuda( + at::Tensor input, + at::Tensor weight, + at::Tensor offset, at::Tensor grad_out, std::pair stride, std::pair pad, std::pair dilation, - int n_weight_grps, int n_offset_grps, int n_parallel_imgs) { + int n_weight_grps, + int n_offset_grps, + int n_parallel_imgs) { at::DeviceGuard guard(input.device()); int batch_sz = input.size(0); @@ -541,62 +740,122 @@ std::tuple deform_conv_backward_input_cuda( auto grad_input = at::zeros_like(input); auto grad_offset = at::zeros_like(offset); - auto columns = at::zeros({n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, input.options()); + auto columns = at::zeros( + {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, + input.options()); // Separate into blocks - grad_input = grad_input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - input = input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - grad_offset = grad_offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); - offset = offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); - - grad_out = grad_out.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_out_channels, out_h, out_w}); + grad_input = grad_input.view( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + input = input.view( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + grad_offset = grad_offset.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + offset = offset.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + grad_out = grad_out.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_out_channels, + out_h, + out_w}); grad_out.transpose_(1, 2); - grad_out = grad_out.view( - {grad_out.size(0), n_weight_grps, grad_out.size(1) / n_weight_grps, - grad_out.size(2), grad_out.size(3), grad_out.size(4)}); + grad_out = grad_out.view({grad_out.size(0), + n_weight_grps, + grad_out.size(1) / n_weight_grps, + grad_out.size(2), + grad_out.size(3), + grad_out.size(4)}); for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { // Separate into weight groups - columns = columns.view({n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); - weight = weight.view({n_weight_grps, weight.size(0) / n_weight_grps, weight.size(1), weight.size(2), weight.size(3)}); + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + weight = weight.view({n_weight_grps, + weight.size(0) / n_weight_grps, + weight.size(1), + weight.size(2), + weight.size(3)}); for (int g = 0; g < n_weight_grps; g++) { - columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); + columns[g] = columns[g].addmm_( + weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); } - columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - - compute_grad_offset(columns, input[elt], offset[elt], n_in_channels, - in_h, in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, - dil_h, dil_w, n_parallel_imgs, n_offset_grps, - grad_offset[elt]); - - compute_grad_input(columns, offset[elt], n_in_channels, in_h, - in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dil_h, - dil_w, n_parallel_imgs, n_offset_grps, grad_input[elt]); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + + compute_grad_offset( + columns, + input[elt], + offset[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dil_h, + dil_w, + n_parallel_imgs, + n_offset_grps, + grad_offset[elt]); + + compute_grad_input( + columns, + offset[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dil_h, + dil_w, + n_parallel_imgs, + n_offset_grps, + grad_input[elt]); } - grad_out = grad_out.view( - {grad_out.size(0), grad_out.size(1) * grad_out.size(2), - grad_out.size(3), grad_out.size(4), grad_out.size(5)}); + grad_out = grad_out.view({grad_out.size(0), + grad_out.size(1) * grad_out.size(2), + grad_out.size(3), + grad_out.size(4), + grad_out.size(5)}); grad_out.transpose_(1, 2); grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w}); grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); input = input.view({batch_sz, n_in_channels, in_h, in_w}); - grad_offset = grad_offset.view({batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); - offset = offset.view({batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + grad_offset = grad_offset.view( + {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + offset = offset.view( + {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); return {grad_input, grad_offset}; } - - -at::Tensor deform_conv_backward_parameters_cuda( - at::Tensor input, at::Tensor weight, at::Tensor offset, +static at::Tensor deform_conv_backward_parameters_cuda( + at::Tensor input, + at::Tensor weight, + at::Tensor offset, at::Tensor grad_out, std::pair stride, std::pair pad, std::pair dilation, - int n_weight_grps, int n_offset_grps, int n_parallel_imgs) { + int n_weight_grps, + int n_offset_grps, + int n_parallel_imgs) { at::DeviceGuard guard(input.device()); int batch_sz = input.size(0); @@ -622,50 +881,95 @@ at::Tensor deform_conv_backward_parameters_cuda( long out_h = grad_out.size(2); long out_w = grad_out.size(3); - auto grad_weight = at::zeros_like(weight);; - auto columns = at::zeros({n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, input.options()); - - grad_out = grad_out.view({batch_sz / n_parallel_imgs, n_parallel_imgs, - n_out_channels, out_h, out_w}); + auto grad_weight = at::zeros_like(weight); + ; + auto columns = at::zeros( + {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, + input.options()); + + grad_out = grad_out.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_out_channels, + out_h, + out_w}); grad_out.transpose_(1, 2); at::Tensor grad_out_buf = at::zeros_like(grad_out); grad_out_buf.copy_(grad_out); - grad_out_buf = grad_out_buf.view({batch_sz / n_parallel_imgs, n_out_channels, n_parallel_imgs * out_h, out_w}); - grad_out_buf = grad_out_buf.view({grad_out_buf.size(0), n_weight_grps, grad_out_buf.size(1) / n_weight_grps, grad_out_buf.size(2), grad_out_buf.size(3)}); + grad_out_buf = grad_out_buf.view({batch_sz / n_parallel_imgs, + n_out_channels, + n_parallel_imgs * out_h, + out_w}); + grad_out_buf = grad_out_buf.view({grad_out_buf.size(0), + n_weight_grps, + grad_out_buf.size(1) / n_weight_grps, + grad_out_buf.size(2), + grad_out_buf.size(3)}); grad_out.transpose_(1, 2); grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w}); - input = input.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - offset = offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); - - grad_weight = grad_weight.view({n_weight_grps, grad_weight.size(0) / n_weight_grps, grad_weight.size(1), grad_weight.size(2), grad_weight.size(3)}); + input = input.view( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + offset = offset.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + grad_weight = grad_weight.view({n_weight_grps, + grad_weight.size(0) / n_weight_grps, + grad_weight.size(1), + grad_weight.size(2), + grad_weight.size(3)}); for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { - deformable_im2col(input[elt], offset[elt], n_in_channels, in_h, - in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, dil_h, - dil_w, out_h, out_w, n_parallel_imgs, n_offset_grps, columns); - - columns = columns.view({n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + deformable_im2col( + input[elt], + offset[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dil_h, + dil_w, + out_h, + out_w, + n_parallel_imgs, + n_offset_grps, + columns); + + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); for (int g = 0; g < n_weight_grps; g++) { - grad_weight[g] = grad_weight[g] - .flatten(1) - .addmm_(grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) - .view_as(grad_weight[g]); + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_( + grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) + .view_as(grad_weight[g]); } - columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); } input = input.view({batch_sz, n_in_channels, in_h, in_w}); - offset = offset.view({batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + offset = offset.view( + {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), - grad_weight.size(2), grad_weight.size(3), grad_weight.size(4)}); + grad_weight.size(2), + grad_weight.size(3), + grad_weight.size(4)}); return grad_weight; } - -std::tuple DeformConv2d_backward_cuda( +std::tuple +DeformConv2d_backward_cuda( const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& weight, @@ -674,22 +978,38 @@ std::tuple DeformConv2d_backward std::pair stride, std::pair pad, std::pair dilation, - int n_weight_grps, int n_offset_grps) { + int n_weight_grps, + int n_offset_grps) { const int batch_sz = input.size(0); - const int n_parallel_imgs = get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); + const int n_parallel_imgs = + get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); auto grad_input_and_offset = deform_conv_backward_input_cuda( - input, weight, offset, grad_out, - stride, pad, dilation, - n_weight_grps, n_offset_grps, n_parallel_imgs); + input, + weight, + offset, + grad_out, + stride, + pad, + dilation, + n_weight_grps, + n_offset_grps, + n_parallel_imgs); auto grad_input = std::get<0>(grad_input_and_offset); auto grad_offset = std::get<1>(grad_input_and_offset); auto grad_weight = deform_conv_backward_parameters_cuda( - input, weight, offset, grad_out, - stride, pad, dilation, - n_weight_grps, n_offset_grps, n_parallel_imgs); + input, + weight, + offset, + grad_out, + stride, + pad, + dilation, + n_weight_grps, + n_offset_grps, + n_parallel_imgs); auto value = grad_out.sum({0, 2, 3}); auto grad_bias = at::ones_like(bias) * value; diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index fe7655a52aa..36e6b3d090b 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -97,7 +97,8 @@ at::Tensor DeformConv2d_forward_cuda( int groups, int deformable_groups); -std::tuple DeformConv2d_backward_cuda( +std::tuple +DeformConv2d_backward_cuda( const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& weight, diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index 12cc7c87675..b761dc88710 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -5,13 +5,13 @@ #include #endif +#include "DeformConv.h" #include "PSROIAlign.h" #include "PSROIPool.h" #include "ROIAlign.h" #include "ROIPool.h" #include "empty_tensor_op.h" #include "nms.h" -#include "DeformConv.h" // If we are in a Windows environment, we need to define // initialization functions for the _custom_ops extension From 65749cc5c238eef7f775c9383e2806515121eab3 Mon Sep 17 00:00:00 2001 From: Pedro Freire Date: Wed, 20 Nov 2019 11:20:34 +0000 Subject: [PATCH 4/9] Import Optional type annotation --- torchvision/ops/deform_conv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/deform_conv.py b/torchvision/ops/deform_conv.py index bb72a390f87..5b3cf27487e 100644 --- a/torchvision/ops/deform_conv.py +++ b/torchvision/ops/deform_conv.py @@ -5,7 +5,7 @@ from torch.nn import init from torch.nn.parameter import Parameter from torch.nn.modules.utils import _pair -from torch.jit.annotations import Tuple +from torch.jit.annotations import Optional, Tuple def deform_conv2d(input, weight, offset, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1)): From 4e2dadf5dcda5d00dc94b75e526fa405b424da3e Mon Sep 17 00:00:00 2001 From: Pedro Freire Date: Fri, 22 Nov 2019 15:28:52 +0000 Subject: [PATCH 5/9] Remove offset param from DeformConv2d module - We pass the offset in the forward of DeformConv2d, instead of having an internal parameter. This adds some complexity to creating the module (e.g. now you have to worry about the output size, to create the offset), but it gives more flexibility. - We also use make_tuple for tuple creation, in an attempt to fix error w/ older compilers. --- test/test_ops.py | 13 ++++++------ torchvision/csrc/cpu/DeformConv_cpu.cpp | 5 +++-- torchvision/csrc/cuda/DeformConv_cuda.cu | 4 ++-- torchvision/ops/deform_conv.py | 27 ++++++++++-------------- 4 files changed, 22 insertions(+), 27 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 2deef26b861..e3477a0332c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -462,7 +462,7 @@ def get_fn_args(self, device, contiguous): return x, weight, offset, bias, stride, pad, dilation def _test_forward(self, device, contiguous): - x, _, _, _, stride, padding, dilation = self.get_fn_args(device, contiguous) + x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous) in_channels = 6 out_channels = 2 kernel_size = (3, 2) @@ -470,13 +470,12 @@ def _test_forward(self, device, contiguous): offset_groups = 3 layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, - dilation=dilation, groups=groups, offset_groups=offset_groups) - layer.offset_conv.weight.data = torch.randn_like(layer.offset_conv.weight.data) - res = layer(x) + dilation=dilation, groups=groups, offset_groups=offset_groups).to(device=x.device, + dtype=x.dtype) + res = layer(x, offset) - weight = layer.weight.data.to(device=x.device, dtype=x.dtype) - offset = layer.offset_conv.to(device=x.device, dtype=x.dtype)(x) - bias = layer.bias.data.to(device=x.device, dtype=x.dtype) + weight = layer.weight.data + bias = layer.bias.data expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation) self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(res, expected)) diff --git a/torchvision/csrc/cpu/DeformConv_cpu.cpp b/torchvision/csrc/cpu/DeformConv_cpu.cpp index 94590b341f6..25dcca73fc2 100644 --- a/torchvision/csrc/cpu/DeformConv_cpu.cpp +++ b/torchvision/csrc/cpu/DeformConv_cpu.cpp @@ -71,6 +71,7 @@ #include #include +#include using namespace at; @@ -812,7 +813,7 @@ static std::tuple deform_conv2d_backward_input_cpu( offset = offset.view( {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); - return {grad_input, grad_offset}; + return std::make_tuple(grad_input, grad_offset); } static at::Tensor deform_conv2d_backward_parameters_cpu( @@ -981,5 +982,5 @@ DeformConv2d_backward_cpu( auto grad_bias = at::ones_like(bias) * grad_out.sum({0, 2, 3}); - return {grad_input, grad_weight, grad_offset, grad_bias}; + return std::make_tuple(grad_input, grad_weight, grad_offset, grad_bias); } diff --git a/torchvision/csrc/cuda/DeformConv_cuda.cu b/torchvision/csrc/cuda/DeformConv_cuda.cu index 008385d86a9..d3dedb3dd60 100644 --- a/torchvision/csrc/cuda/DeformConv_cuda.cu +++ b/torchvision/csrc/cuda/DeformConv_cuda.cu @@ -842,7 +842,7 @@ static std::tuple deform_conv_backward_input_cuda( offset = offset.view( {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); - return {grad_input, grad_offset}; + return std::make_tuple(grad_input, grad_offset); } static at::Tensor deform_conv_backward_parameters_cuda( @@ -1014,5 +1014,5 @@ DeformConv2d_backward_cuda( auto value = grad_out.sum({0, 2, 3}); auto grad_bias = at::ones_like(bias) * value; - return {grad_input, grad_weight, grad_offset, grad_bias}; + return std::make_tuple(grad_input, grad_weight, grad_offset, grad_bias); } diff --git a/torchvision/ops/deform_conv.py b/torchvision/ops/deform_conv.py index 5b3cf27487e..3c4eec144eb 100644 --- a/torchvision/ops/deform_conv.py +++ b/torchvision/ops/deform_conv.py @@ -81,14 +81,6 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.weight = Parameter(torch.empty(out_channels, in_channels // groups, kernel_size[0], kernel_size[1])) - self.offset_conv = nn.Conv2d( - self.in_channels, - offset_groups * 2 * self.kernel_size[0] * self.kernel_size[1], - kernel_size=self.kernel_size, - stride=self.stride, - padding=self.padding, - dilation=self.dilation) - if bias: self.bias = Parameter(torch.empty(out_channels)) else: @@ -98,19 +90,22 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, def reset_parameters(self): init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - init.zeros_(self.offset_conv.weight) - init.zeros_(self.offset_conv.bias) if self.bias is not None: fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) - def forward(self, input): - offset = self.offset_conv.to(device=input.device, dtype=input.dtype)(input) - weight = self.weight.to(device=input.device, dtype=input.dtype) - bias = self.bias.to(device=input.device, dtype=input.dtype) if self.bias is not None else self.bias - - return deform_conv2d(input, weight, offset, bias, stride=self.stride, + def forward(self, input, offset): + """ + Arguments: + input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor + weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]): + convolution weights, split into groups of size (in_channels // groups) + offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width, + out_height, out_width]): offsets to be applied for each position in the + convolution kernel. + """ + return deform_conv2d(input, self.weight, offset, self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation) def __repr__(self): From a6092beba47954b2daffd36875f1dd1471f8b65a Mon Sep 17 00:00:00 2001 From: Pedro Freire Date: Fri, 29 Nov 2019 00:25:07 +0100 Subject: [PATCH 6/9] Replace abs by std::abs Old gcc versions were giving wrong results here, because they would resolve abs as int -> int, thus causing undesired truncation. Replacing abs by std::abs should allow for correct overloading of abs as float -> float. --- torchvision/csrc/cpu/DeformConv_cpu.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/csrc/cpu/DeformConv_cpu.cpp b/torchvision/csrc/cpu/DeformConv_cpu.cpp index 25dcca73fc2..4ccaa8a02f5 100644 --- a/torchvision/csrc/cpu/DeformConv_cpu.cpp +++ b/torchvision/csrc/cpu/DeformConv_cpu.cpp @@ -70,6 +70,7 @@ #include #include +#include #include #include @@ -455,9 +456,9 @@ static void deformable_col2im_kernel( int yp = int(y) + dy; int xp = int(x) + dx; if (0 <= yp && yp < height && 0 <= xp && xp < width && - abs(y - yp) < 1 && abs(x - xp) < 1) { + std::abs(y - yp) < 1 && std::abs(x - xp) < 1) { int grad_pos = ((b * channels + c) * height + yp) * width + xp; - scalar_t weight = (1 - abs(y - yp)) * (1 - abs(x - xp)); + scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp)); grad_im[grad_pos] += weight * col[index]; } } From 3d90f2c5c13cea25a9d60a26a2a438a690a706d0 Mon Sep 17 00:00:00 2001 From: Pedro Freire Date: Fri, 29 Nov 2019 00:33:16 +0100 Subject: [PATCH 7/9] Reorder declarations for clarity --- torchvision/csrc/cpu/DeformConv_cpu.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchvision/csrc/cpu/DeformConv_cpu.cpp b/torchvision/csrc/cpu/DeformConv_cpu.cpp index 4ccaa8a02f5..0095f25328e 100644 --- a/torchvision/csrc/cpu/DeformConv_cpu.cpp +++ b/torchvision/csrc/cpu/DeformConv_cpu.cpp @@ -428,6 +428,9 @@ static void deformable_col2im_kernel( const int out_w, scalar_t* grad_im) { for (int index = 0; index != n; ++index) { + const int out_x = index % out_w; + const int out_y = (index / out_w) % out_h; + const int b = (index / (out_w * out_h)) % batch_sz; const int j = (index / (out_w * out_h * batch_sz)) % kernel_w; const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h; const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h); @@ -435,10 +438,6 @@ static void deformable_col2im_kernel( int c_per_offset_grp = channels / n_offset_grps; const int offset_grp = c / c_per_offset_grp; - int out_x = index % out_w; - int out_y = (index / out_w) % out_h; - int b = (index / (out_w * out_h)) % batch_sz; - auto offset_ptr = offset + (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h * out_w; From 929eecc479661b38269fc404fd8b70fabe4bc41e Mon Sep 17 00:00:00 2001 From: Pedro Freire Date: Fri, 29 Nov 2019 00:36:25 +0100 Subject: [PATCH 8/9] Reorder weight and offset args in deform_conv2d We place offset arg before the weight arg, to be more consistent with DeformConv2d.forward(input, offset) --- test/test_ops.py | 14 +++++++------- torchvision/ops/deform_conv.py | 8 ++++---- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index e3477a0332c..d87c53b4355 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -483,18 +483,18 @@ def _test_forward(self, device, contiguous): def _test_backward(self, device, contiguous): x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous) - def func(x_, weight_, offset_, bias_): - return ops.deform_conv2d(x_, weight_, offset_, bias_, stride=stride, padding=padding, dilation=dilation) + def func(x_, offset_, weight_, bias_): + return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation) - gradcheck(func, (x, weight, offset, bias), nondet_tol=1e-5) + gradcheck(func, (x, offset, weight, bias), nondet_tol=1e-5) @torch.jit.script - def script_func(x_, weight_, offset_, bias_, stride_, pad_, dilation_): + def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_): # type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor - return ops.deform_conv2d(x_, weight_, offset_, bias_, stride=stride_, padding=pad_, dilation=dilation_) + return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_) - gradcheck(lambda z, wei, off, bi: script_func(z, wei, off, bi, stride, padding, dilation), - (x, weight, offset, bias), nondet_tol=1e-5) + gradcheck(lambda z, off, wei, bi: script_func(z, off, wei, bi, stride, padding, dilation), + (x, offset, weight, bias), nondet_tol=1e-5) if __name__ == '__main__': diff --git a/torchvision/ops/deform_conv.py b/torchvision/ops/deform_conv.py index 3c4eec144eb..98672dbcfde 100644 --- a/torchvision/ops/deform_conv.py +++ b/torchvision/ops/deform_conv.py @@ -8,18 +8,18 @@ from torch.jit.annotations import Optional, Tuple -def deform_conv2d(input, weight, offset, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1)): +def deform_conv2d(input, offset, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1)): # type: (Tensor, Tensor, Tensor, Optional[Tensor], Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor """ Performs Deformable Convolution, described in Deformable Convolutional Networks Arguments: input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor - weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]): - convolution weights, split into groups of size (in_channels // groups) offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width, out_height, out_width]): offsets to be applied for each position in the convolution kernel. + weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]): + convolution weights, split into groups of size (in_channels // groups) bias (Tensor[out_channels]): optional bias of shape (out_channels,). Default: None stride (int or Tuple[int, int]): distance between convolution centers. Default: 1 padding (int or Tuple[int, int]): height/width of padding of zeroes around @@ -105,7 +105,7 @@ def forward(self, input, offset): out_height, out_width]): offsets to be applied for each position in the convolution kernel. """ - return deform_conv2d(input, self.weight, offset, self.bias, stride=self.stride, + return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation) def __repr__(self): From 5aebed3ce472b89fc1eb466195df79d5063b0175 Mon Sep 17 00:00:00 2001 From: Pedro Freire Date: Fri, 29 Nov 2019 00:51:10 +0100 Subject: [PATCH 9/9] Replace abs by std::abs in DeformConv_cuda --- torchvision/csrc/cuda/DeformConv_cuda.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchvision/csrc/cuda/DeformConv_cuda.cu b/torchvision/csrc/cuda/DeformConv_cuda.cu index d3dedb3dd60..8c472df3b26 100644 --- a/torchvision/csrc/cuda/DeformConv_cuda.cu +++ b/torchvision/csrc/cuda/DeformConv_cuda.cu @@ -74,7 +74,9 @@ #include "cuda_helpers.h" +#include #include +#include using namespace at; @@ -470,9 +472,9 @@ __global__ void deformable_col2im_gpu_kernel( int yp = int(y) + dy; int xp = int(x) + dx; if (0 <= yp && yp < height && 0 <= xp && xp < width && - abs(y - yp) < 1 && abs(x - xp) < 1) { + std::abs(y - yp) < 1 && std::abs(x - xp) < 1) { int grad_pos = ((b * channels + c) * height + yp) * width + xp; - scalar_t weight = (1 - abs(y - yp)) * (1 - abs(x - xp)); + scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp)); atomicAdd(grad_im + grad_pos, weight * col[index]); } }