From c261e92d86593ac2718d6068321668020e3e56ef Mon Sep 17 00:00:00 2001 From: Licht Takeuchi Date: Mon, 9 Nov 2020 19:34:51 +0900 Subject: [PATCH] Add modulation input for DeformConv2D (#2791) * Add modulation input for DeformConv2D * lint * Patch for GPU CI * Remove bad cache on CI --- .circleci/config.yml | 14 +- .circleci/config.yml.in | 14 +- test/test_ops.py | 74 ++++-- torchvision/csrc/DeformConv.h | 84 +++++-- torchvision/csrc/cpu/DeformConv_cpu.cpp | 236 ++++++++++++++--- torchvision/csrc/cpu/vision_cpu.h | 37 +-- torchvision/csrc/cuda/DeformConv_cuda.cu | 307 ++++++++++++++++++----- torchvision/csrc/cuda/vision_cuda.h | 37 +-- torchvision/csrc/vision.cpp | 4 +- torchvision/ops/deform_conv.py | 28 ++- 10 files changed, 642 insertions(+), 193 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 3c1975eaa50..2f6c41357fe 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -455,6 +455,7 @@ jobs: resource_class: gpu.small environment: image_name: "pytorch/manylinux-cuda101" + PYTHON_VERSION: << parameters.python_version >> steps: - checkout - designate_upload_channel @@ -462,14 +463,9 @@ jobs: name: Generate cache key # This will refresh cache on Sundays, nightly build should generate new cache. command: echo "$(date +"%Y-%U")" > .circleci-weekly - - restore_cache: - - keys: - - env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} - - run: name: Setup - command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh + command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh - save_cache: key: env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} @@ -533,6 +529,7 @@ jobs: name: windows-gpu environment: CUDA_VERSION: "10.1" + PYTHON_VERSION: << parameters.python_version >> steps: - checkout - designate_upload_channel @@ -540,11 +537,6 @@ jobs: name: Generate cache key # This will refresh cache on Sundays, nightly build should generate new cache. command: echo "$(date +"%Y-%U")" > .circleci-weekly - - restore_cache: - - keys: - - env-v1-windows-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/windows/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} - - run: name: Setup command: .circleci/unittest/windows/scripts/setup_env.sh diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index f05022e47ca..50f7041afab 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -455,6 +455,7 @@ jobs: resource_class: gpu.small environment: image_name: "pytorch/manylinux-cuda101" + PYTHON_VERSION: << parameters.python_version >> steps: - checkout - designate_upload_channel @@ -462,14 +463,9 @@ jobs: name: Generate cache key # This will refresh cache on Sundays, nightly build should generate new cache. 
command: echo "$(date +"%Y-%U")" > .circleci-weekly - - restore_cache: - {% raw %} - keys: - - env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} - {% endraw %} - run: name: Setup - command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh + command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh - save_cache: {% raw %} key: env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} @@ -533,6 +529,7 @@ jobs: name: windows-gpu environment: CUDA_VERSION: "10.1" + PYTHON_VERSION: << parameters.python_version >> steps: - checkout - designate_upload_channel @@ -540,11 +537,6 @@ jobs: name: Generate cache key # This will refresh cache on Sundays, nightly build should generate new cache. command: echo "$(date +"%Y-%U")" > .circleci-weekly - - restore_cache: - {% raw %} - keys: - - env-v1-windows-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/windows/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} - {% endraw %} - run: name: Setup command: .circleci/unittest/windows/scripts/setup_env.sh diff --git a/test/test_ops.py b/test/test_ops.py index 5570d969cd7..1ba40d0da5f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -458,7 +458,7 @@ def test_new_empty_tensor(self): class DeformConvTester(OpTester, unittest.TestCase): - def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1): + def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1): stride_h, stride_w = _pair(stride) pad_h, pad_w = _pair(padding) dil_h, dil_w = _pair(dilation) @@ -489,12 +489,17 @@ def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1): c_in = weight_grp * in_c_per_weight_grp + c offset_grp = c_in // in_c_per_offset_grp - offset_idx = 2 * (offset_grp * (weight_h * weight_w) + di * weight_w + dj) + mask_idx = offset_grp * (weight_h * weight_w) + di * weight_w + dj + offset_idx = 2 * mask_idx pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j] pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j] - out[b, c_out, i, j] += (weight[c_out, c, di, dj] * + mask_value = 1.0 + if mask is not None: + mask_value = mask[b, mask_idx, i, j] + + out[b, c_out, i, j] += (mask_value * weight[c_out, c, di, dj] * bilinear_interpolate(x[b, c_in, :, :], pi, pj)) out += bias.view(1, n_out_channels, 1, 1) return out @@ -523,6 +528,9 @@ def get_fn_args(self, device, contiguous, batch_sz, dtype): offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w, device=device, dtype=dtype, requires_grad=True) + mask = torch.randn(batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w, + device=device, dtype=dtype, requires_grad=True) + weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w, device=device, dtype=dtype, requires_grad=True) @@ -531,9 +539,10 @@ def get_fn_args(self, device, contiguous, batch_sz, dtype): if not contiguous: x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2) offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1) + mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1) weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0) - return 
x, weight, offset, bias, stride, pad, dilation + return x, weight, offset, mask, bias, stride, pad, dilation def _test_forward(self, device, contiguous, dtype=None): dtype = self.dtype if dtype is None else dtype @@ -541,21 +550,28 @@ def _test_forward(self, device, contiguous, dtype=None): self._test_forward_with_batchsize(device, contiguous, batch_sz, dtype) def _test_forward_with_batchsize(self, device, contiguous, batch_sz, dtype): - x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype) + x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype) in_channels = 6 out_channels = 2 kernel_size = (3, 2) groups = 2 + tol = 1e-3 if dtype is torch.half else 1e-5 layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups).to(device=x.device, dtype=dtype) - res = layer(x, offset) + res = layer(x, offset, mask) weight = layer.weight.data bias = layer.bias.data - expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation) + expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation) + + self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol), + '\nres:\n{}\nexpected:\n{}'.format(res, expected)) + + # no modulation test + res = layer(x, offset) + expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation) - tol = 1e-3 if dtype is torch.half else 1e-5 self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol), '\nres:\n{}\nexpected:\n{}'.format(res, expected)) @@ -564,24 +580,46 @@ def _test_forward_with_batchsize(self, device, contiguous, batch_sz, dtype): wrong_offset = torch.rand_like(offset[:, :2]) res = layer(x, wrong_offset) + with self.assertRaises(RuntimeError): + wrong_mask = torch.rand_like(mask[:, :2]) + res = layer(x, offset, wrong_mask) + def _test_backward(self, device, contiguous): for batch_sz in [0, 33]: self._test_backward_with_batchsize(device, contiguous, batch_sz) def _test_backward_with_batchsize(self, device, contiguous, batch_sz): - x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, self.dtype) + x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, + batch_sz, self.dtype) + + def func(x_, offset_, mask_, weight_, bias_): + return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, + padding=padding, dilation=dilation, mask=mask_) - def func(x_, offset_, weight_, bias_): - return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation) + gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5) + + def func_no_mask(x_, offset_, weight_, bias_): + return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, + padding=padding, dilation=dilation, mask=None) + + gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5) + + @torch.jit.script + def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_): + # type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor + return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, + padding=pad_, dilation=dilation_, mask=mask_) - gradcheck(func, (x, offset, weight, bias), nondet_tol=1e-5) + gradcheck(lambda z, off, msk, wei, bi: script_func(z, 
off, msk, wei, bi, stride, padding, dilation), + (x, offset, mask, weight, bias), nondet_tol=1e-5) @torch.jit.script - def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_): - # type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor - return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_) + def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_): + # type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor + return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, + padding=pad_, dilation=dilation_, mask=None) - gradcheck(lambda z, off, wei, bi: script_func(z, off, wei, bi, stride, padding, dilation), + gradcheck(lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation), (x, offset, weight, bias), nondet_tol=1e-5) # Test from https://github.com/pytorch/vision/issues/2598 @@ -593,17 +631,19 @@ def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_): init_weight = torch.randn(9, 9, 3, 3, requires_grad=True) img = torch.randn(8, 9, 1000, 110) offset = torch.rand(8, 2 * 3 * 3, 1000, 110) + mask = torch.rand(8, 3 * 3, 1000, 110) if not contiguous: img = img.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2) offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1) + mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1) weight = init_weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0) else: weight = init_weight for d in ["cpu", "cuda"]: - out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1) + out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d)) out.mean().backward() if true_cpu_grads is None: true_cpu_grads = init_weight.grad diff --git a/torchvision/csrc/DeformConv.h b/torchvision/csrc/DeformConv.h index e09401f88ad..f8a8dba60e6 100644 --- a/torchvision/csrc/DeformConv.h +++ b/torchvision/csrc/DeformConv.h @@ -17,6 +17,7 @@ at::Tensor deform_conv2d( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& offset, + const at::Tensor& mask, const at::Tensor& bias, int64_t stride_h, int64_t stride_w, @@ -25,7 +26,8 @@ at::Tensor deform_conv2d( int64_t dilation_h, int64_t dilation_w, int64_t groups, - int64_t offset_groups) { + int64_t offset_groups, + bool use_mask) { static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("torchvision::deform_conv2d", "") .typed(); @@ -33,6 +35,7 @@ at::Tensor deform_conv2d( input, weight, offset, + mask, bias, stride_h, stride_w, @@ -41,7 +44,8 @@ at::Tensor deform_conv2d( dilation_h, dilation_w, groups, - offset_groups); + offset_groups, + use_mask); } #if defined(WITH_CUDA) || defined(WITH_HIP) @@ -49,6 +53,7 @@ at::Tensor DeformConv2d_autocast( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& offset, + const at::Tensor& mask, const at::Tensor& bias, int64_t stride_h, int64_t stride_w, @@ -57,12 +62,14 @@ at::Tensor DeformConv2d_autocast( int64_t dilation_h, int64_t dilation_w, int64_t groups, - int64_t offset_groups) { + int64_t offset_groups, + bool use_mask) { c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); return deform_conv2d( at::autocast::cached_cast(at::kFloat, input), at::autocast::cached_cast(at::kFloat, weight), at::autocast::cached_cast(at::kFloat, offset), + at::autocast::cached_cast(at::kFloat, mask), at::autocast::cached_cast(at::kFloat, bias), stride_h, 
stride_w, @@ -71,17 +78,19 @@ at::Tensor DeformConv2d_autocast( dilation_h, dilation_w, groups, - offset_groups) + offset_groups, + use_mask) .to(input.scalar_type()); } #endif -std::tuple +std::tuple _deform_conv2d_backward( const at::Tensor& grad, const at::Tensor& input, const at::Tensor& weight, const at::Tensor& offset, + const at::Tensor& mask, const at::Tensor& bias, int64_t stride_h, int64_t stride_w, @@ -90,7 +99,8 @@ _deform_conv2d_backward( int64_t dilation_h, int64_t dilation_w, int64_t groups, - int64_t offset_groups) { + int64_t offset_groups, + bool use_mask) { static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "") @@ -100,6 +110,7 @@ _deform_conv2d_backward( input, weight, offset, + mask, bias, stride_h, stride_w, @@ -108,7 +119,8 @@ _deform_conv2d_backward( dilation_h, dilation_w, groups, - offset_groups); + offset_groups, + use_mask); } class DeformConv2dFunction @@ -119,6 +131,7 @@ class DeformConv2dFunction const torch::autograd::Variable& input, const torch::autograd::Variable& weight, const torch::autograd::Variable& offset, + const torch::autograd::Variable& mask, const torch::autograd::Variable& bias, int64_t stride_h, int64_t stride_w, @@ -127,12 +140,14 @@ class DeformConv2dFunction int64_t dilation_h, int64_t dilation_w, int64_t groups, - int64_t offset_groups) { + int64_t offset_groups, + bool use_mask) { at::AutoNonVariableTypeMode g; auto output = deform_conv2d( input, weight, offset, + mask, bias, stride_h, stride_w, @@ -141,9 +156,10 @@ class DeformConv2dFunction dilation_h, dilation_w, groups, - offset_groups); + offset_groups, + use_mask); - ctx->save_for_backward({input, weight, offset, bias}); + ctx->save_for_backward({input, weight, offset, mask, bias}); ctx->saved_data["stride_h"] = stride_h; ctx->saved_data["stride_w"] = stride_w; ctx->saved_data["pad_h"] = pad_h; @@ -152,6 +168,7 @@ class DeformConv2dFunction ctx->saved_data["dilation_w"] = dilation_w; ctx->saved_data["groups"] = groups; ctx->saved_data["offset_groups"] = offset_groups; + ctx->saved_data["use_mask"] = use_mask; return { output, @@ -165,7 +182,8 @@ class DeformConv2dFunction auto input = saved[0]; auto weight = saved[1]; auto offset = saved[2]; - auto bias = saved[3]; + auto mask = saved[3]; + auto bias = saved[4]; auto stride_h = ctx->saved_data["stride_h"].toInt(); auto stride_w = ctx->saved_data["stride_w"].toInt(); @@ -175,12 +193,14 @@ class DeformConv2dFunction auto dilation_w = ctx->saved_data["dilation_w"].toInt(); auto groups = ctx->saved_data["groups"].toInt(); auto offset_groups = ctx->saved_data["offset_groups"].toInt(); + auto use_mask = ctx->saved_data["use_mask"].toBool(); auto grads = _deform_conv2d_backward( grad_output[0], input, weight, offset, + mask, bias, stride_h, stride_w, @@ -189,16 +209,19 @@ class DeformConv2dFunction dilation_h, dilation_w, groups, - offset_groups); + offset_groups, + use_mask); auto grad_input = std::get<0>(grads); auto grad_weight = std::get<1>(grads); auto grad_offset = std::get<2>(grads); - auto grad_bias = std::get<3>(grads); + auto grad_mask = std::get<3>(grads); + auto grad_bias = std::get<4>(grads); return { grad_input, grad_weight, grad_offset, + grad_mask, grad_bias, torch::autograd::Variable(), torch::autograd::Variable(), @@ -208,6 +231,7 @@ class DeformConv2dFunction torch::autograd::Variable(), torch::autograd::Variable(), torch::autograd::Variable(), + torch::autograd::Variable(), }; } }; @@ -222,6 +246,7 @@ class DeformConv2dBackwardFunction const 
torch::autograd::Variable& input, const torch::autograd::Variable& weight, const torch::autograd::Variable& offset, + const torch::autograd::Variable& mask, const torch::autograd::Variable& bias, int64_t stride_h, int64_t stride_w, @@ -230,13 +255,15 @@ class DeformConv2dBackwardFunction int64_t dilation_h, int64_t dilation_w, int64_t groups, - int64_t offset_groups) { + int64_t offset_groups, + bool use_mask) { at::AutoNonVariableTypeMode g; auto result = _deform_conv2d_backward( grad, input, weight, offset, + mask, bias, stride_h, stride_w, @@ -245,17 +272,20 @@ class DeformConv2dBackwardFunction dilation_h, dilation_w, groups, - offset_groups); + offset_groups, + use_mask); auto grad_input = std::get<0>(result); auto grad_weight = std::get<1>(result); auto grad_offset = std::get<2>(result); - auto grad_bias = std::get<3>(result); + auto grad_mask = std::get<3>(result); + auto grad_bias = std::get<4>(result); return { grad_input, grad_weight, grad_offset, + grad_mask, grad_bias, }; } @@ -271,6 +301,7 @@ at::Tensor DeformConv2d_autograd( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& offset, + const at::Tensor& mask, const at::Tensor& bias, int64_t stride_h, int64_t stride_w, @@ -279,11 +310,13 @@ at::Tensor DeformConv2d_autograd( int64_t dilation_h, int64_t dilation_w, int64_t groups, - int64_t offset_groups) { + int64_t offset_groups, + bool use_mask) { return DeformConv2dFunction::apply( input, weight, offset, + mask, bias, stride_h, stride_w, @@ -292,15 +325,17 @@ at::Tensor DeformConv2d_autograd( dilation_h, dilation_w, groups, - offset_groups)[0]; + offset_groups, + use_mask)[0]; } -std::tuple +std::tuple DeformConv2d_backward_autograd( const at::Tensor& grad, const at::Tensor& input, const at::Tensor& weight, const at::Tensor& offset, + const at::Tensor& mask, const at::Tensor& bias, int64_t stride_h, int64_t stride_w, @@ -309,12 +344,14 @@ DeformConv2d_backward_autograd( int64_t dilation_h, int64_t dilation_w, int64_t groups, - int64_t offset_groups) { + int64_t offset_groups, + bool use_mask) { auto result = DeformConv2dBackwardFunction::apply( grad, input, weight, offset, + mask, bias, stride_h, stride_w, @@ -323,7 +360,8 @@ DeformConv2d_backward_autograd( dilation_h, dilation_w, groups, - offset_groups); + offset_groups, + use_mask); - return std::make_tuple(result[0], result[1], result[2], result[3]); -} \ No newline at end of file + return std::make_tuple(result[0], result[1], result[2], result[3], result[4]); +} diff --git a/torchvision/csrc/cpu/DeformConv_cpu.cpp b/torchvision/csrc/cpu/DeformConv_cpu.cpp index 18e845f5508..0212be55aa4 100644 --- a/torchvision/csrc/cpu/DeformConv_cpu.cpp +++ b/torchvision/csrc/cpu/DeformConv_cpu.cpp @@ -120,6 +120,7 @@ static void deformable_im2col_kernel( int n, const scalar_t* input, const scalar_t* offset, + const scalar_t* mask, int height, int width, int weight_h, @@ -135,6 +136,7 @@ static void deformable_im2col_kernel( int n_offset_grps, int out_h, int out_w, + bool use_mask, scalar_t* columns) { for (int index = 0; index != n; ++index) { const int out_x = index % out_w; @@ -157,16 +159,31 @@ static void deformable_im2col_kernel( (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w; + auto mask_ptr = mask; + if (use_mask) { + mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * + out_h * out_w; + } + for (int i = 0; i < weight_h; ++i) { for (int j = 0; j < weight_w; ++j) { - const int offset_idx = 2 * (i * weight_w + j); + const int mask_idx = i * weight_w + j; + const int 
offset_idx = 2 * mask_idx; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = + mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; + } + const scalar_t offset_h = offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; const scalar_t offset_w = offset_ptr [(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; const scalar_t y = (out_y * stride_h - pad_h) + i * dil_h + offset_h; const scalar_t x = (out_x * stride_w - pad_w) + j * dil_w + offset_w; - *columns_ptr = bilinear_interpolate(input_ptr, height, width, y, x); + *columns_ptr = + mask_value * bilinear_interpolate(input_ptr, height, width, y, x); columns_ptr += batch_sz * out_h * out_w; } } @@ -176,6 +193,7 @@ static void deformable_im2col( const at::Tensor& input, const at::Tensor& data_offset, + const at::Tensor& data_mask, int n_in_channels, int height, int width, @@ -191,6 +209,7 @@ static void deformable_im2col( int out_w, int parallel_imgs, int deformable_group, + bool use_mask, at::Tensor data_col) { int num_kernels = n_in_channels * out_h * out_w * parallel_imgs; @@ -200,6 +219,7 @@ num_kernels, input.data_ptr<scalar_t>(), data_offset.data_ptr<scalar_t>(), + data_mask.data_ptr<scalar_t>(), height, width, weight_h, @@ -215,6 +235,7 @@ deformable_group, out_h, out_w, + use_mask, data_col.data_ptr<scalar_t>()); })); } @@ -232,6 +253,7 @@ at::Tensor DeformConv2d_forward_cpu( const at::Tensor& input_param, const at::Tensor& weight_param, const at::Tensor& offset_param, + const at::Tensor& mask_param, const at::Tensor& bias_param, int64_t stride_h, int64_t stride_w, @@ -240,14 +262,17 @@ int64_t dil_h, int64_t dil_w, int64_t n_weight_grps, - int64_t n_offset_grps) { + int64_t n_offset_grps, + bool use_mask) { at::Tensor input = input_param.contiguous(); at::Tensor offset = offset_param.contiguous(); at::Tensor weight = weight_param.contiguous(); + at::Tensor mask = mask_param.contiguous(); at::Tensor bias = bias_param.contiguous(); TORCH_CHECK(input.ndimension() == 4); TORCH_CHECK(offset.ndimension() == 4); + TORCH_CHECK(!use_mask || mask.ndimension() == 4); TORCH_CHECK(weight.ndimension() == 4); TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); @@ -292,6 +317,12 @@ offset.size(1), " expected: ", n_offset_grps * 2 * weight_h * weight_w); + TORCH_CHECK( + (!use_mask || mask.size(1) == n_offset_grps * weight_h * weight_w), + "mask.shape[1] is not valid: got: ", + mask.size(1), + " expected: ", + n_offset_grps * weight_h * weight_w); TORCH_CHECK(input.size(1) % n_offset_grps == 0); TORCH_CHECK( @@ -308,6 +339,19 @@ ", ", out_w, ")"); + TORCH_CHECK((mask.size(0) == input.size(0)), "invalid batch size of mask"); + TORCH_CHECK( + (!use_mask || (mask.size(2) == out_h && mask.size(3) == out_w)), + "mask output dims: (", + mask.size(2), + ", ", + mask.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); TORCH_CHECK( out_h > 0 && out_w > 0, "Calculated output size too small - out_h: ", @@ -328,11 +372,21 @@ out_w}); input = input.view( {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + offset = offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + + if (use_mask) { + mask = mask.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h *
weight_w, + out_h, + out_w}); + } + at::Tensor out_buf = at::zeros( {batch_sz / n_parallel_imgs, out_channels, @@ -360,6 +414,7 @@ at::Tensor DeformConv2d_forward_cpu( deformable_im2col( input[b], offset[b], + mask[b], n_in_channels, in_h, in_w, @@ -375,6 +430,7 @@ at::Tensor DeformConv2d_forward_cpu( out_w, n_parallel_imgs, n_offset_grps, + use_mask, columns); columns = columns.view( @@ -406,6 +462,7 @@ static void deformable_col2im_kernel( int n, const scalar_t* col, const scalar_t* offset, + const scalar_t* mask, int channels, int height, int width, @@ -421,6 +478,7 @@ static void deformable_col2im_kernel( int n_offset_grps, int out_h, int out_w, + bool use_mask, scalar_t* grad_im) { for (int index = 0; index != n; ++index) { const int out_x = index % out_w; @@ -436,12 +494,27 @@ static void deformable_col2im_kernel( auto offset_ptr = offset + (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h * out_w; - const int offset_h_ptr = - ((2 * (i * kernel_w + j)) * out_h + out_y) * out_w + out_x; - const int offset_w_ptr = - ((2 * (i * kernel_w + j) + 1) * out_h + out_y) * out_w + out_x; + + auto mask_ptr = mask; + if (use_mask) { + mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w * + out_h * out_w; + } + + const int mask_idx = i * kernel_w + j; + const int offset_idx = 2 * mask_idx; + + const int offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x; + const int offset_w_ptr = ((offset_idx + 1) * out_h + out_y) * out_w + out_x; + const scalar_t offset_h = offset_ptr[offset_h_ptr]; const scalar_t offset_w = offset_ptr[offset_w_ptr]; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; + } + const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; @@ -453,7 +526,7 @@ static void deformable_col2im_kernel( std::abs(y - yp) < 1 && std::abs(x - xp) < 1) { int grad_pos = ((b * channels + c) * height + yp) * width + xp; scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp)); - grad_im[grad_pos] += weight * col[index]; + grad_im[grad_pos] += mask_value * weight * col[index]; } } } @@ -463,6 +536,7 @@ static void deformable_col2im_kernel( static void compute_grad_input( const at::Tensor& columns, const at::Tensor& offset, + const at::Tensor& mask, int channels, int height, int width, @@ -476,6 +550,7 @@ static void compute_grad_input( int dilation_w, int parallel_imgs, int n_offset_grps, + bool use_mask, at::Tensor grad_im) { int out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; @@ -490,6 +565,7 @@ static void compute_grad_input( num_kernels, columns.data_ptr(), offset.data_ptr(), + mask.data_ptr(), channels, height, width, @@ -505,6 +581,7 @@ static void compute_grad_input( n_offset_grps, out_h, out_w, + use_mask, grad_im.data_ptr()); })); } @@ -548,6 +625,7 @@ static void deformable_col2im_coord_kernel( const scalar_t* col, const scalar_t* im, const scalar_t* offset, + const scalar_t* mask, int channels, int height, int width, @@ -564,11 +642,17 @@ static void deformable_col2im_coord_kernel( int n_offset_grps, int out_h, int out_w, - scalar_t* grad_offset) { + bool use_mask, + scalar_t* grad_offset, + scalar_t* grad_mask) { for (int index = 0; index != n; ++index) { - scalar_t val = 0; + scalar_t grad_offset_val = 0; + scalar_t grad_mask_val = 0; + int w = index % out_w; int h = (index / out_w) % out_h; + int w_w = (index / (out_w * out_h * 2)) % weight_w; + int w_h 
= (index / (out_w * out_h * 2 * weight_w)) % weight_h; int c = (index / (out_w * out_h)) % offset_channels; int b = index / (out_w * out_h * offset_channels); @@ -586,6 +670,12 @@ static void deformable_col2im_coord_kernel( (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h * out_w; + auto mask_ptr = mask; + if (use_mask) { + mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w * + out_h * out_w; + } + const int offset_c = c - offset_grp * 2 * weight_h * weight_w; const bool is_y_direction = offset_c % 2 == 0; @@ -598,30 +688,55 @@ static void deformable_col2im_coord_kernel( int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w; int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; + const int mask_idx = i * weight_w + j; + const int offset_h_idx = - (((2 * (i * weight_w + j)) * out_h + out_y) * out_w + out_x); + (((2 * mask_idx) * out_h + out_y) * out_w + out_x); const int offset_w_idx = - (((2 * (i * weight_w + j) + 1) * out_h + out_y) * out_w + out_x); + (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x); const scalar_t offset_h = offset_ptr[offset_h_idx]; const scalar_t offset_w = offset_ptr[offset_w_idx]; + scalar_t mask_value = 1; + if (use_mask) { + mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; + } + scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; const scalar_t weight = get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); - val += weight * col_ptr[col_pos]; + grad_offset_val += mask_value * weight * col_ptr[col_pos]; + + if (use_mask && is_y_direction) { + grad_mask_val += col_ptr[col_pos] * + bilinear_interpolate(im_ptr, height, width, y, x); + } + im_ptr += height * width; } - grad_offset[index] = val; + grad_offset[index] = grad_offset_val; + + if (use_mask && is_y_direction) { + const int idx = + ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w + + w_w) * + out_h + + h) * + out_w + + w; + grad_mask[idx] = grad_mask_val; + } } } -static void compute_grad_offset( +static void compute_grad_offset_and_mask( const at::Tensor& columns, const at::Tensor& input, const at::Tensor& offset, + const at::Tensor& mask, int channels, int height, int width, @@ -635,7 +750,9 @@ static void compute_grad_offset( int dilation_w, int parallel_imgs, int n_offset_grps, - at::Tensor grad_offset) { + bool use_mask, + at::Tensor grad_offset, + at::Tensor grad_mask) { int out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; int out_w = @@ -650,6 +767,7 @@ static void compute_grad_offset( columns.data_ptr(), input.data_ptr(), offset.data_ptr(), + mask.data_ptr(), channels, height, width, @@ -666,14 +784,18 @@ static void compute_grad_offset( n_offset_grps, out_h, out_w, - grad_offset.data_ptr()); + use_mask, + grad_offset.data_ptr(), + grad_mask.data_ptr()); })); } -static std::tuple deform_conv2d_backward_input_cpu( +static std::tuple +deform_conv2d_backward_input_cpu( at::Tensor input, at::Tensor weight, at::Tensor offset, + at::Tensor mask, at::Tensor grad_out, int stride_h, int stride_w, @@ -683,7 +805,8 @@ static std::tuple deform_conv2d_backward_input_cpu( int dil_w, int n_weight_grps, int n_offset_grps, - int n_parallel_imgs) { + int n_parallel_imgs, + bool use_mask) { int batch_sz = input.size(0); int n_in_channels = input.size(1); int in_h = input.size(2); @@ -700,9 +823,12 @@ static std::tuple deform_conv2d_backward_input_cpu( auto grad_input = at::zeros_like(input); auto 
grad_offset = at::zeros_like(offset); + auto grad_mask = at::zeros_like(mask); + if (batch_sz == 0) { - return std::make_tuple(grad_input, grad_offset); + return std::make_tuple(grad_input, grad_offset, grad_mask); } + auto columns = at::empty( {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, input.options()); @@ -712,6 +838,7 @@ static std::tuple deform_conv2d_backward_input_cpu( {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); input = input.reshape( {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + grad_offset = grad_offset.reshape({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, @@ -723,6 +850,19 @@ static std::tuple deform_conv2d_backward_input_cpu( out_h, out_w}); + if (use_mask) { + grad_mask = grad_mask.reshape({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + mask = mask.reshape({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + grad_out = grad_out .reshape({batch_sz / n_parallel_imgs, n_parallel_imgs, @@ -749,10 +889,11 @@ static std::tuple deform_conv2d_backward_input_cpu( weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); } - compute_grad_offset( + compute_grad_offset_and_mask( columns, input[elt], offset[elt], + mask[elt], n_in_channels, in_h, in_w, @@ -766,11 +907,14 @@ static std::tuple deform_conv2d_backward_input_cpu( dil_w, n_parallel_imgs, n_offset_grps, - grad_offset[elt]); + use_mask, + grad_offset[elt], + grad_mask[elt]); compute_grad_input( columns, offset[elt], + mask[elt], n_in_channels, in_h, in_w, @@ -784,6 +928,7 @@ static std::tuple deform_conv2d_backward_input_cpu( dil_w, n_parallel_imgs, n_offset_grps, + use_mask, grad_input[elt]); } @@ -791,13 +936,19 @@ static std::tuple deform_conv2d_backward_input_cpu( grad_offset = grad_offset.view( {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); - return std::make_tuple(grad_input, grad_offset); + if (use_mask) { + grad_mask = grad_mask.view( + {batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w}); + } + + return std::make_tuple(grad_input, grad_offset, grad_mask); } static at::Tensor deform_conv2d_backward_parameters_cpu( at::Tensor input, const at::Tensor& weight, at::Tensor offset, + at::Tensor mask, const at::Tensor& grad_out, int stride_h, int stride_w, @@ -807,7 +958,8 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( int dil_w, int n_weight_grps, int n_offset_grps, - int n_parallel_imgs) { + int n_parallel_imgs, + bool use_mask) { int batch_sz = input.size(0); int n_in_channels = input.size(1); int in_h = input.size(2); @@ -839,12 +991,21 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( input = input.reshape( {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + offset = offset.reshape({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + if (use_mask) { + mask = mask.reshape({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + grad_weight = grad_weight.view({n_weight_grps, grad_weight.size(0) / n_weight_grps, grad_weight.size(1), @@ -861,6 +1022,7 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( deformable_im2col( input[elt], offset[elt], + mask[elt], n_in_channels, in_h, in_w, @@ -876,6 +1038,7 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( out_w, 
n_parallel_imgs, n_offset_grps, + use_mask, columns); for (int g = 0; g < n_weight_grps; g++) { @@ -895,12 +1058,13 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( return grad_weight; } -std::tuple +std::tuple DeformConv2d_backward_cpu( const at::Tensor& grad_out_param, const at::Tensor& input_param, const at::Tensor& weight_param, const at::Tensor& offset_param, + const at::Tensor& mask_param, const at::Tensor& bias_param, int64_t stride_h, int64_t stride_w, @@ -909,21 +1073,24 @@ DeformConv2d_backward_cpu( int64_t dil_h, int64_t dil_w, int64_t n_weight_grps, - int64_t n_offset_grps) { + int64_t n_offset_grps, + bool use_mask) { at::Tensor grad_out = grad_out_param.contiguous(); at::Tensor input = input_param.contiguous(); at::Tensor weight = weight_param.contiguous(); at::Tensor offset = offset_param.contiguous(); + at::Tensor mask = mask_param.contiguous(); at::Tensor bias = bias_param.contiguous(); const int batch_sz = input.size(0); const int n_parallel_imgs = get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); - auto grad_input_and_offset = deform_conv2d_backward_input_cpu( + auto grad_input_and_offset_and_mask = deform_conv2d_backward_input_cpu( input, weight, offset, + mask, grad_out, stride_h, stride_w, @@ -933,15 +1100,18 @@ DeformConv2d_backward_cpu( dil_w, n_weight_grps, n_offset_grps, - n_parallel_imgs); + n_parallel_imgs, + use_mask); - auto grad_input = std::get<0>(grad_input_and_offset); - auto grad_offset = std::get<1>(grad_input_and_offset); + auto grad_input = std::get<0>(grad_input_and_offset_and_mask); + auto grad_offset = std::get<1>(grad_input_and_offset_and_mask); + auto grad_mask = std::get<2>(grad_input_and_offset_and_mask); auto grad_weight = deform_conv2d_backward_parameters_cpu( input, weight, offset, + mask, grad_out, stride_h, stride_w, @@ -951,9 +1121,11 @@ DeformConv2d_backward_cpu( dil_w, n_weight_grps, n_offset_grps, - n_parallel_imgs); + n_parallel_imgs, + use_mask); auto grad_bias = at::ones_like(bias) * grad_out.sum({0, 2, 3}); - return std::make_tuple(grad_input, grad_weight, grad_offset, grad_bias); + return std::make_tuple( + grad_input, grad_weight, grad_offset, grad_mask, grad_bias); } diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index 14def9d324f..d5bfcc0de24 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -6,6 +6,7 @@ VISION_API at::Tensor DeformConv2d_forward_cpu( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& offset, + const at::Tensor& mask, const at::Tensor& bias, int64_t stride_h, int64_t stride_w, @@ -14,23 +15,27 @@ VISION_API at::Tensor DeformConv2d_forward_cpu( int64_t dilation_h, int64_t dilation_w, int64_t groups, - int64_t deformable_groups); + int64_t deformable_groups, + bool use_mask); -VISION_API std::tuple -DeformConv2d_backward_cpu( - const at::Tensor& grad_out, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t deformable_groups); +VISION_API std:: + tuple + DeformConv2d_backward_cpu( + const at::Tensor& grad_out, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t 
deformable_groups, + bool use_mask); VISION_API at::Tensor nms_cpu( const at::Tensor& dets, diff --git a/torchvision/csrc/cuda/DeformConv_cuda.cu b/torchvision/csrc/cuda/DeformConv_cuda.cu index dc53b26c8a0..c6e9a9278ed 100644 --- a/torchvision/csrc/cuda/DeformConv_cuda.cu +++ b/torchvision/csrc/cuda/DeformConv_cuda.cu @@ -78,12 +78,19 @@ #include #include -const unsigned int CUDA_NUM_THREADS = 1024; const int kMaxParallelImgs = 32; -inline unsigned int GET_BLOCKS(const unsigned int N) { - unsigned int kMaxGridNum = at::cuda::getCurrentDeviceProperties()->maxGridSize[0]; - return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); +inline unsigned int GET_THREADS() { + if (at::cuda::getCurrentDeviceProperties()->major >= 6) { + return 1024; + } + return 512; +} + +inline unsigned int GET_BLOCKS(const unsigned int THREADS, const unsigned int N) { + unsigned int kMaxGridNum = + at::cuda::getCurrentDeviceProperties()->maxGridSize[0]; + return std::min(kMaxGridNum, (N + THREADS - 1) / THREADS); } template <typename scalar_t> @@ -130,6 +137,7 @@ __global__ void deformable_im2col_gpu_kernel( int n, const scalar_t* input_ptr, const scalar_t* offset_ptr, + const scalar_t* mask_ptr, int height, int width, int weight_h, @@ -145,6 +153,7 @@ int n_offset_grps, int out_h, int out_w, + bool use_mask, scalar_t* columns_ptr) { CUDA_1D_KERNEL_LOOP(index, n) { const int out_x = index % out_w; @@ -166,16 +175,30 @@ offset_ptr += (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w; + if (use_mask) { + mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * + out_h * out_w; + } + for (int i = 0; i < weight_h; ++i) { for (int j = 0; j < weight_w; ++j) { - const int offset_idx = 2 * (i * weight_w + j); + const int mask_idx = i * weight_w + j; + const int offset_idx = 2 * mask_idx; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = + mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; + } + const scalar_t offset_h = offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; const scalar_t offset_w = offset_ptr [(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; const scalar_t y = (out_y * stride_h - pad_h) + i * dil_h + offset_h; const scalar_t x = (out_x * stride_w - pad_w) + j * dil_w + offset_w; - *columns_ptr = bilinear_interpolate(input_ptr, height, width, y, x); + *columns_ptr = + mask_value * bilinear_interpolate(input_ptr, height, width, y, x); columns_ptr += batch_sz * out_h * out_w; } } @@ -185,6 +208,7 @@ static void deformable_im2col( const at::Tensor& input, const at::Tensor& data_offset, + const at::Tensor& data_mask, int n_in_channels, int height, int width, @@ -200,17 +224,22 @@ int out_w, int parallel_imgs, int deformable_group, + bool use_mask, at::Tensor data_col) { int num_kernels = n_in_channels * out_h * out_w * parallel_imgs; + const unsigned int threads = GET_THREADS(); + const unsigned int blocks = GET_BLOCKS(threads, num_kernels); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( input.scalar_type(), "deformable_im2col_gpu", ([&] { deformable_im2col_gpu_kernel<scalar_t><<< - GET_BLOCKS(num_kernels), - CUDA_NUM_THREADS>>>( + blocks, + threads>>>( num_kernels, input.data_ptr<scalar_t>(), data_offset.data_ptr<scalar_t>(), + data_mask.data_ptr<scalar_t>(), height, width, weight_h, @@ -226,6 +255,7 @@ deformable_group, out_h, out_w, + use_mask, data_col.data_ptr<scalar_t>()); })); @@
-248,6 +278,7 @@ at::Tensor DeformConv2d_forward_cuda( const at::Tensor& input_param, const at::Tensor& weight_param, const at::Tensor& offset_param, + const at::Tensor& mask_param, const at::Tensor& bias_param, int64_t stride_h, int64_t stride_w, @@ -256,14 +287,17 @@ at::Tensor DeformConv2d_forward_cuda( int64_t dil_h, int64_t dil_w, int64_t n_weight_grps, - int64_t n_offset_grps) { + int64_t n_offset_grps, + bool use_mask) { at::Tensor input = input_param.contiguous(); at::Tensor offset = offset_param.contiguous(); at::Tensor weight = weight_param.contiguous(); + at::Tensor mask = mask_param.contiguous(); at::Tensor bias = bias_param.contiguous(); TORCH_CHECK(input.ndimension() == 4); TORCH_CHECK(offset.ndimension() == 4); + TORCH_CHECK(!use_mask || mask.ndimension() == 4); TORCH_CHECK(weight.ndimension() == 4); TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); @@ -309,6 +343,12 @@ at::Tensor DeformConv2d_forward_cuda( offset.size(1), " expected: ", n_offset_grps * 2 * weight_h * weight_w); + TORCH_CHECK( + (!use_mask || mask.size(1) == n_offset_grps * weight_h * weight_w), + "mask.shape[1] is not valid: got: ", + mask.size(1), + " expected: ", + n_offset_grps * weight_h * weight_w); TORCH_CHECK(input.size(1) % n_offset_grps == 0); TORCH_CHECK( @@ -325,6 +365,19 @@ at::Tensor DeformConv2d_forward_cuda( ", ", out_w, ")"); + TORCH_CHECK((mask.size(0) == input.size(0)), "invalid batch size of mask"); + TORCH_CHECK( + (!use_mask || (mask.size(2) == out_h && mask.size(3) == out_w)), + "mask output dims: (", + mask.size(2), + ", ", + mask.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); TORCH_CHECK( out_h > 0 && out_w > 0, "Calculated output size too small - out_h: ", @@ -345,11 +398,21 @@ at::Tensor DeformConv2d_forward_cuda( out_w}); input = input.view( {batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w}); + offset = offset.view({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + + if (use_mask) { + mask = mask.view({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + at::Tensor out_buf = at::zeros( {batch_sz / n_parallel_imgs, out_channels, @@ -377,6 +440,7 @@ at::Tensor DeformConv2d_forward_cuda( deformable_im2col( input[b], offset[b], + mask[b], in_channels, in_h, in_w, @@ -392,6 +456,7 @@ at::Tensor DeformConv2d_forward_cuda( out_w, n_parallel_imgs, n_offset_grps, + use_mask, columns); columns = columns.view( @@ -402,8 +467,8 @@ at::Tensor DeformConv2d_forward_cuda( .addmm_(weight[g].flatten(1), columns[g]) .view_as(out_buf[b][g]); } - columns = columns.view( - {columns.size(0) * columns.size(1), columns.size(2)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); } out_buf = out_buf.view({batch_sz / n_parallel_imgs, @@ -423,6 +488,7 @@ __global__ void deformable_col2im_gpu_kernel( int n, const scalar_t* col, const scalar_t* offset_ptr, + const scalar_t* mask_ptr, int channels, int height, int width, @@ -438,6 +504,7 @@ __global__ void deformable_col2im_gpu_kernel( int n_offset_grps, int out_h, int out_w, + bool use_mask, scalar_t* grad_im) { CUDA_1D_KERNEL_LOOP(index, n) { const int out_x = index % out_w; @@ -452,12 +519,26 @@ __global__ void deformable_col2im_gpu_kernel( offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h * out_w; - const int offset_h_ptr = - ((2 * (i * kernel_w + j)) * out_h + out_y) * out_w + out_x; - const int offset_w_ptr = - ((2 * 
(i * kernel_w + j) + 1) * out_h + out_y) * out_w + out_x; + + if (use_mask) { + mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w * + out_h * out_w; + } + + const int mask_idx = i * kernel_w + j; + const int offset_idx = 2 * mask_idx; + + const int offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x; + const int offset_w_ptr = ((offset_idx + 1) * out_h + out_y) * out_w + out_x; + const scalar_t offset_h = offset_ptr[offset_h_ptr]; const scalar_t offset_w = offset_ptr[offset_w_ptr]; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; + } + const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; @@ -469,7 +550,7 @@ __global__ void deformable_col2im_gpu_kernel( std::abs(y - yp) < 1 && std::abs(x - xp) < 1) { int grad_pos = ((b * channels + c) * height + yp) * width + xp; scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp)); - atomicAdd(grad_im + grad_pos, weight * col[index]); + atomicAdd(grad_im + grad_pos, mask_value * weight * col[index]); } } } @@ -479,6 +560,7 @@ __global__ void deformable_col2im_gpu_kernel( static void compute_grad_input( const at::Tensor& columns, const at::Tensor& offset, + const at::Tensor& mask, int channels, int height, int width, @@ -492,6 +574,7 @@ static void compute_grad_input( int dilation_w, int parallel_imgs, int n_offset_grps, + bool use_mask, at::Tensor grad_im) { int out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; @@ -500,14 +583,18 @@ static void compute_grad_input( int num_kernels = channels * weight_h * weight_w * out_h * out_w * parallel_imgs; + const unsigned int threads = GET_THREADS(); + const unsigned int blocks = GET_BLOCKS(threads, num_kernels); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( columns.scalar_type(), "deformable_col2im_gpu", ([&] { deformable_col2im_gpu_kernel<<< - GET_BLOCKS(num_kernels), - CUDA_NUM_THREADS>>>( + blocks, + threads>>>( num_kernels, columns.data_ptr(), offset.data_ptr(), + mask.data_ptr(), channels, height, width, @@ -523,6 +610,7 @@ static void compute_grad_input( n_offset_grps, out_h, out_w, + use_mask, grad_im.data_ptr()); })); @@ -571,6 +659,7 @@ __global__ void deformable_col2im_coord_gpu_kernel( const scalar_t* col_ptr, const scalar_t* im_ptr, const scalar_t* offset_ptr, + const scalar_t* mask_ptr, int channels, int height, int width, @@ -587,11 +676,17 @@ __global__ void deformable_col2im_coord_gpu_kernel( int n_offset_grps, int out_h, int out_w, - scalar_t* grad_offset) { + const bool use_mask, + scalar_t* grad_offset, + scalar_t* grad_mask) { CUDA_1D_KERNEL_LOOP(index, n) { - scalar_t val = 0; + scalar_t grad_offset_val = 0; + scalar_t grad_mask_val = 0; + int w = index % out_w; int h = (index / out_w) % out_h; + int w_w = (index / (out_w * out_h * 2)) % weight_w; + int w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h; int c = (index / (out_w * out_h)) % offset_channels; int b = index / (out_w * out_h * offset_channels); @@ -607,6 +702,11 @@ __global__ void deformable_col2im_coord_gpu_kernel( offset_ptr += (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h * out_w; + if (use_mask) { + mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w * + out_h * out_w; + } + const int offset_c = c - offset_grp * 2 * weight_h * weight_w; const bool is_y_direction = offset_c % 2 == 0; @@ -619,30 +719,55 @@ __global__ void deformable_col2im_coord_gpu_kernel( int j = 
(col_pos / (out_w * out_h * batch_sz)) % weight_w; int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; + const int mask_idx = i * weight_w + j; + const int offset_h_ptr = - (((2 * (i * weight_w + j)) * out_h + out_y) * out_w + out_x); + (((2 * mask_idx) * out_h + out_y) * out_w + out_x); const int offset_w_ptr = - (((2 * (i * weight_w + j) + 1) * out_h + out_y) * out_w + out_x); + (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x); const scalar_t offset_h = offset_ptr[offset_h_ptr]; const scalar_t offset_w = offset_ptr[offset_w_ptr]; + scalar_t mask_value = 1; + if (use_mask) { + mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; + } + scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; const scalar_t weight = get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); - val += weight * col_ptr[col_pos]; + grad_offset_val += mask_value * weight * col_ptr[col_pos]; + + if (use_mask && is_y_direction) { + grad_mask_val += col_ptr[col_pos] * + bilinear_interpolate(im_ptr, height, width, y, x); + } + im_ptr += height * width; } - grad_offset[index] = val; + grad_offset[index] = grad_offset_val; + + if (use_mask && is_y_direction) { + const int idx = + ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w + + w_w) * + out_h + + h) * + out_w + + w; + grad_mask[idx] = grad_mask_val; + } } } -static void compute_grad_offset( +static void compute_grad_offset_and_mask( const at::Tensor& columns, const at::Tensor& input, const at::Tensor& offset, + const at::Tensor& mask, int channels, int height, int width, @@ -656,7 +781,9 @@ static void compute_grad_offset( int dilation_w, int parallel_imgs, int n_offset_grps, - at::Tensor grad_offset) { + bool use_mask, + at::Tensor grad_offset, + at::Tensor grad_mask) { int out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; int out_w = @@ -664,15 +791,19 @@ static void compute_grad_offset( int num_kernels = out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs; + const unsigned int threads = GET_THREADS(); + const unsigned int blocks = GET_BLOCKS(threads, num_kernels); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( columns.scalar_type(), "deformable_col2im_coord_gpu", ([&] { deformable_col2im_coord_gpu_kernel<<< - GET_BLOCKS(num_kernels), - CUDA_NUM_THREADS>>>( + blocks, + threads>>>( num_kernels, columns.data_ptr(), input.data_ptr(), offset.data_ptr(), + mask.data_ptr(), channels, height, width, @@ -689,19 +820,23 @@ static void compute_grad_offset( n_offset_grps, out_h, out_w, - grad_offset.data_ptr()); + use_mask, + grad_offset.data_ptr(), + grad_mask.data_ptr()); })); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { - printf("error in compute_grad_offset: %s\n", cudaGetErrorString(err)); + printf( + "error in compute_grad_offset_and_mask: %s\n", cudaGetErrorString(err)); } } -static std::tuple deform_conv2d_backward_input_cuda( +static std::tuple deform_conv2d_backward_input_cuda( at::Tensor input, at::Tensor weight, at::Tensor offset, + at::Tensor mask, at::Tensor grad_out, int stride_h, int stride_w, @@ -711,7 +846,8 @@ static std::tuple deform_conv2d_backward_input_cuda( int dil_w, int n_weight_grps, int n_offset_grps, - int n_parallel_imgs) { + int n_parallel_imgs, + bool use_mask) { at::DeviceGuard guard(input.device()); int batch_sz = input.size(0); @@ -730,9 +866,12 @@ static std::tuple deform_conv2d_backward_input_cuda( auto grad_input = 
at::zeros_like(input); auto grad_offset = at::zeros_like(offset); + auto grad_mask = at::zeros_like(mask); + if (batch_sz == 0) { - return std::make_tuple(grad_input, grad_offset); + return std::make_tuple(grad_input, grad_offset, grad_mask); } + auto columns = at::empty( {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, input.options()); @@ -742,6 +881,7 @@ static std::tuple deform_conv2d_backward_input_cuda( {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); input = input.reshape( {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + grad_offset = grad_offset.reshape({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, @@ -753,12 +893,27 @@ static std::tuple deform_conv2d_backward_input_cuda( out_h, out_w}); - grad_out = grad_out.reshape({batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_weight_grps, - n_out_channels / n_weight_grps, - out_h, - out_w}).permute({0, 2, 3, 1, 4, 5}); + if (use_mask) { + grad_mask = grad_mask.reshape({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + mask = mask.reshape({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + grad_out = grad_out + .reshape({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_weight_grps, + n_out_channels / n_weight_grps, + out_h, + out_w}) + .permute({0, 2, 3, 1, 4, 5}); weight = weight.reshape({n_weight_grps, weight.size(0) / n_weight_grps, @@ -776,10 +931,11 @@ static std::tuple deform_conv2d_backward_input_cuda( weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); } - compute_grad_offset( + compute_grad_offset_and_mask( columns, input[elt], offset[elt], + mask[elt], n_in_channels, in_h, in_w, @@ -793,11 +949,14 @@ static std::tuple deform_conv2d_backward_input_cuda( dil_w, n_parallel_imgs, n_offset_grps, - grad_offset[elt]); + use_mask, + grad_offset[elt], + grad_mask[elt]); compute_grad_input( columns, offset[elt], + mask[elt], n_in_channels, in_h, in_w, @@ -811,21 +970,27 @@ static std::tuple deform_conv2d_backward_input_cuda( dil_w, n_parallel_imgs, n_offset_grps, + use_mask, grad_input[elt]); } - grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); grad_offset = grad_offset.view( {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); - return std::make_tuple(grad_input, grad_offset); + if (use_mask) { + grad_mask = grad_mask.view( + {batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w}); + } + + return std::make_tuple(grad_input, grad_offset, grad_mask); } static at::Tensor deform_conv2d_backward_parameters_cuda( at::Tensor input, const at::Tensor& weight, at::Tensor offset, + at::Tensor mask, const at::Tensor& grad_out, int stride_h, int stride_w, @@ -835,7 +1000,8 @@ static at::Tensor deform_conv2d_backward_parameters_cuda( int dil_w, int n_weight_grps, int n_offset_grps, - int n_parallel_imgs) { + int n_parallel_imgs, + bool use_mask) { at::DeviceGuard guard(input.device()); int batch_sz = input.size(0); @@ -857,23 +1023,33 @@ static at::Tensor deform_conv2d_backward_parameters_cuda( return grad_weight; } - at::Tensor grad_out_buf = grad_out.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_weight_grps, - n_out_channels / n_weight_grps, - out_h, - out_w} - ).permute({0, 2, 3, 1, 4, 5}).contiguous(); + at::Tensor grad_out_buf = grad_out + .reshape({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_weight_grps, + n_out_channels / 
n_weight_grps, + out_h, + out_w}) + .permute({0, 2, 3, 1, 4, 5}) + .contiguous(); input = input.reshape( {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + offset = offset.reshape({batch_sz / n_parallel_imgs, n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + if (use_mask) { + mask = mask.reshape({batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + grad_weight = grad_weight.reshape({n_weight_grps, grad_weight.size(0) / n_weight_grps, grad_weight.size(1), @@ -890,6 +1066,7 @@ static at::Tensor deform_conv2d_backward_parameters_cuda( deformable_im2col( input[elt], offset[elt], + mask[elt], n_in_channels, in_h, in_w, @@ -905,6 +1082,7 @@ static at::Tensor deform_conv2d_backward_parameters_cuda( out_w, n_parallel_imgs, n_offset_grps, + use_mask, columns); for (int g = 0; g < n_weight_grps; g++) { @@ -924,12 +1102,13 @@ static at::Tensor deform_conv2d_backward_parameters_cuda( return grad_weight; } -std::tuple +std::tuple DeformConv2d_backward_cuda( const at::Tensor& grad_out_param, const at::Tensor& input_param, const at::Tensor& weight_param, const at::Tensor& offset_param, + const at::Tensor& mask_param, const at::Tensor& bias_param, int64_t stride_h, int64_t stride_w, @@ -938,21 +1117,24 @@ DeformConv2d_backward_cuda( int64_t dil_h, int64_t dil_w, int64_t n_weight_grps, - int64_t n_offset_grps) { + int64_t n_offset_grps, + bool use_mask) { at::Tensor grad_out = grad_out_param.contiguous(); at::Tensor input = input_param.contiguous(); at::Tensor weight = weight_param.contiguous(); at::Tensor offset = offset_param.contiguous(); + at::Tensor mask = mask_param.contiguous(); at::Tensor bias = bias_param.contiguous(); const int batch_sz = input.size(0); const int n_parallel_imgs = get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); - auto grad_input_and_offset = deform_conv2d_backward_input_cuda( + auto grad_input_and_offset_and_mask = deform_conv2d_backward_input_cuda( input, weight, offset, + mask, grad_out, stride_h, stride_w, @@ -962,15 +1144,18 @@ DeformConv2d_backward_cuda( dil_w, n_weight_grps, n_offset_grps, - n_parallel_imgs); + n_parallel_imgs, + use_mask); - auto grad_input = std::get<0>(grad_input_and_offset); - auto grad_offset = std::get<1>(grad_input_and_offset); + auto grad_input = std::get<0>(grad_input_and_offset_and_mask); + auto grad_offset = std::get<1>(grad_input_and_offset_and_mask); + auto grad_mask = std::get<2>(grad_input_and_offset_and_mask); auto grad_weight = deform_conv2d_backward_parameters_cuda( input, weight, offset, + mask, grad_out, stride_h, stride_w, @@ -980,10 +1165,12 @@ DeformConv2d_backward_cuda( dil_w, n_weight_grps, n_offset_grps, - n_parallel_imgs); + n_parallel_imgs, + use_mask); auto value = grad_out.sum({0, 2, 3}); auto grad_bias = at::ones_like(bias) * value; - return std::make_tuple(grad_input, grad_weight, grad_offset, grad_bias); + return std::make_tuple( + grad_input, grad_weight, grad_offset, grad_mask, grad_bias); } diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index 731d119cf75..bf57f1c7967 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -6,6 +6,7 @@ VISION_API at::Tensor DeformConv2d_forward_cuda( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& offset, + const at::Tensor& mask, const at::Tensor& bias, int64_t stride_h, int64_t stride_w, @@ -14,23 +15,27 @@ VISION_API at::Tensor DeformConv2d_forward_cuda( int64_t 
diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp
index bd9a770473d..abfd78c5461 100644
--- a/torchvision/csrc/vision.cpp
+++ b/torchvision/csrc/vision.cpp
@@ -46,9 +46,9 @@ int64_t cuda_version() noexcept {
 
 TORCH_LIBRARY(torchvision, m) {
   m.def(
-      "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor");
+      "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor");
   m.def(
-      "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> (Tensor, Tensor, Tensor, Tensor)");
+      "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
   m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
   m.def(
       "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
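Note: the updated schema above can be invoked directly through the dispatcher; the sketch below shows the positional argument order it implies. User code should normally go through torchvision.ops.deform_conv2d, which fills in the bias/mask defaults and the use_mask flag. Assumes a build that includes this patch.

    import torch
    import torchvision  # registers the torchvision:: ops

    input = torch.rand(4, 3, 10, 10)
    weight = torch.rand(5, 3, 3, 3)
    offset = torch.rand(4, 2 * 3 * 3, 8, 8)
    mask = torch.rand(4, 3 * 3, 8, 8)
    bias = torch.zeros(5)

    # stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
    # groups, offset_groups, use_mask -- per the schema string above.
    out = torch.ops.torchvision.deform_conv2d(
        input, weight, offset, mask, bias,
        1, 1, 0, 0, 1, 1, 1, 1, True)
    print(out.shape)  # torch.Size([4, 5, 8, 8])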
diff --git a/torchvision/ops/deform_conv.py b/torchvision/ops/deform_conv.py
index b403da5585f..5377df56146 100644
--- a/torchvision/ops/deform_conv.py
+++ b/torchvision/ops/deform_conv.py
@@ -17,6 +17,7 @@ def deform_conv2d(
     stride: Tuple[int, int] = (1, 1),
     padding: Tuple[int, int] = (0, 0),
     dilation: Tuple[int, int] = (1, 1),
+    mask: Optional[Tensor] = None,
 ) -> Tensor:
     """
     Performs Deformable Convolution, described in Deformable Convolutional Networks
@@ -33,6 +34,9 @@ def deform_conv2d(
         padding (int or Tuple[int, int]): height/width of padding of
             zeroes around each image. Default: 0
         dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1
+        mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width,
+            out_height, out_width]): masks to be applied for each position in the
+            convolution kernel.
 
     Returns:
         output (Tensor[batch_sz, out_channels, out_h, out_w]): result of convolution
@@ -42,11 +46,12 @@ def deform_conv2d(
     >>> input = torch.rand(4, 3, 10, 10)
     >>> kh, kw = 3, 3
     >>> weight = torch.rand(5, 3, kh, kw)
-    >>> # offset should have the same spatial size as the output
+    >>> # offset and mask should have the same spatial size as the output
     >>> # of the convolution. In this case, for an input of 10, stride of 1
     >>> # and kernel size of 3, without padding, the output size is 8
     >>> offset = torch.rand(4, 2 * kh * kw, 8, 8)
-    >>> out = deform_conv2d(input, offset, weight)
+    >>> mask = torch.rand(4, kh * kw, 8, 8)
+    >>> out = deform_conv2d(input, offset, weight, mask=mask)
     >>> print(out.shape)
     >>> # returns
     >>> torch.Size([4, 5, 8, 8])
@@ -54,6 +59,12 @@ def deform_conv2d(
     """
 
     _assert_has_ops()
     out_channels = weight.shape[0]
+
+    use_mask = mask is not None
+
+    if mask is None:
+        mask = torch.zeros((input.shape[0], 0), device=input.device, dtype=input.dtype)
+
     if bias is None:
         bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)
@@ -77,18 +88,21 @@ def deform_conv2d(
         input,
         weight,
         offset,
+        mask,
         bias,
         stride_h, stride_w,
         pad_h, pad_w,
         dil_h, dil_w,
         n_weight_grps,
-        n_offset_grps)
+        n_offset_grps,
+        use_mask,)
 
 
 class DeformConv2d(nn.Module):
     """
     See deform_conv2d
     """
+
     def __init__(
         self,
         in_channels: int,
@@ -127,21 +141,25 @@ def __init__(
 
     def reset_parameters(self) -> None:
         init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+
         if self.bias is not None:
             fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
             bound = 1 / math.sqrt(fan_in)
             init.uniform_(self.bias, -bound, bound)
 
-    def forward(self, input: Tensor, offset: Tensor) -> Tensor:
+    def forward(self, input: Tensor, offset: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         """
         Arguments:
             input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
             offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
                 out_height, out_width]): offsets to be applied for each position in the
                 convolution kernel.
+            mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width,
+                out_height, out_width]): masks to be applied for each position in the
+                convolution kernel.
         """
         return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride,
                              padding=self.padding, dilation=self.dilation, mask=mask)
 
     def __repr__(self) -> str:
         s = self.__class__.__name__ + '('
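Note: a typical way to drive the new mask argument, following the modulated deformable convolution (DCNv2) paper, is to predict offsets and masks with an ordinary side convolution and squash the mask with a sigmoid. This is a usage sketch, not something the patch prescribes; the wrapper name and channel sizes are illustrative only.

    import torch
    import torch.nn as nn
    from torchvision.ops import DeformConv2d

    class ModulatedDeformBlock(nn.Module):
        # Hypothetical wrapper around DeformConv2d for a 3x3 kernel.
        def __init__(self, in_ch, out_ch, kh=3, kw=3):
            super().__init__()
            # 2*kh*kw offset channels plus kh*kw mask channels.
            self.offset_mask = nn.Conv2d(in_ch, 3 * kh * kw, kernel_size=(kh, kw), padding=1)
            self.deform = DeformConv2d(in_ch, out_ch, (kh, kw), padding=1)
            self.kh, self.kw = kh, kw

        def forward(self, x):
            om = self.offset_mask(x)
            split = 2 * self.kh * self.kw
            offset, mask = om[:, :split], om[:, split:]
            # DCNv2 squashes the modulation scalars into (0, 1).
            return self.deform(x, offset, mask.sigmoid())

    block = ModulatedDeformBlock(3, 5)
    print(block(torch.rand(2, 3, 10, 10)).shape)  # torch.Size([2, 5, 10, 10])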