Add Deformable Convolution operation. #1586

Merged · 9 commits · Dec 4, 2019
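For orientation before reading the diff: the PR adds a deformable convolution op to torchvision, exposed as ops.deform_conv2d and the ops.DeformConv2d module exercised by the new tests below. A minimal usage sketch, assuming the API shown in this diff (shapes match the test fixtures; the offset tensor supplies one (y, x) displacement per offset group and kernel tap at every output location):

    import torch
    from torchvision import ops

    x = torch.rand(1, 6, 5, 4)
    # offset channels = offset_groups * 2 * kernel_h * kernel_w,
    # and the spatial size must equal the output size (here 2 x 3).
    offset = torch.randn(1, 3 * 2 * 3 * 2, 2, 3)
    layer = ops.DeformConv2d(6, 2, kernel_size=(3, 2), stride=(2, 1), padding=(1, 0),
                             dilation=(2, 1), groups=2, offset_groups=3)
    out = layer(x, offset)  # shape: (1, 2, 2, 3)
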
193 changes: 157 additions & 36 deletions test/test_ops.py
@@ -1,15 +1,18 @@
 from __future__ import division
 import math
+import unittest

 import numpy as np

 import torch
+from torch import Tensor
 from torch.autograd import gradcheck

+from torch.jit.annotations import Tuple
+from torch.nn.modules.utils import _pair
 from torchvision import ops

 from itertools import product
-import unittest


-class RoIOpTester(object):
+class OpTester(object):
     @classmethod
     def setUpClass(cls):
         cls.dtype = torch.float64
@@ -42,6 +45,14 @@ def test_backward_cuda_contiguous(self):
     def test_backward_cuda_non_contiguous(self):
         self._test_backward(device=torch.device('cuda'), contiguous=False)

+    def _test_forward(self, device, contiguous):
+        pass
+
+    def _test_backward(self, device, contiguous):
+        pass
+
+
+class RoIOpTester(OpTester):
     def _test_forward(self, device, contiguous):
         pool_size = 5
         # n_channels % (pool_size ** 2) == 0 required for PS operations.
@@ -79,7 +90,6 @@ def func(z):

         self.assertTrue(gradcheck(func, (x,)))
         self.assertTrue(gradcheck(script_func, (x,)))
-        return

     def fn(*args, **kwargs):
         pass
@@ -98,7 +108,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
     def get_script_fn(self, rois, pool_size):
         @torch.jit.script
         def script_fn(input, rois, pool_size):
-            # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+            # type: (Tensor, Tensor, int) -> Tensor
             return ops.roi_pool(input, rois, pool_size, 1.0)[0]
         return lambda x: script_fn(x, rois, pool_size)
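
(A note on the type-comment churn in this and the following hunks: the script functions now spell their TorchScript annotations with the bare Tensor name imported at the top of the file, consistent with the Tuple[int, int] annotation needed by the new deform-conv script function at the end of the diff.)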

@@ -137,7 +147,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
     def get_script_fn(self, rois, pool_size):
         @torch.jit.script
         def script_fn(input, rois, pool_size):
-            # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+            # type: (Tensor, Tensor, int) -> Tensor
             return ops.ps_roi_pool(input, rois, pool_size, 1.0)[0]
         return lambda x: script_fn(x, rois, pool_size)

@@ -174,29 +184,35 @@ def get_slice(k, block):
         return y


-def bilinear_interpolate(data, height, width, y, x):
-    if y < -1.0 or y > height or x < -1.0 or x > width:
-        return 0.
+def bilinear_interpolate(data, y, x, snap_border=False):
+    height, width = data.shape

-    y = min(max(0, y), height - 1)
-    x = min(max(0, x), width - 1)
+    if snap_border:
+        if -1 < y <= 0:
+            y = 0
+        elif height - 1 <= y < height:
+            y = height - 1

-    y_low = int(y)
-    y_high = min(y_low + 1, height - 1)
+        if -1 < x <= 0:
+            x = 0
+        elif width - 1 <= x < width:
+            x = width - 1

-    x_low = int(x)
-    x_high = min(x_low + 1, width - 1)
+    y_low = int(math.floor(y))
+    x_low = int(math.floor(x))
+    y_high = y_low + 1
+    x_high = x_low + 1

     wy_h = y - y_low
-    wy_l = 1 - wy_h
-
     wx_h = x - x_low
+    wy_l = 1 - wy_h
     wx_l = 1 - wx_h

     val = 0
-    for wx, x in zip((wx_l, wx_h), (x_low, x_high)):
-        for wy, y in zip((wy_l, wy_h), (y_low, y_high)):
-            val += wx * wy * data[y * width + x]
+    for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
+        for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
+            if 0 <= yp < height and 0 <= xp < width:
+                val += wx * wy * data[yp, xp]
     return val
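
A quick hand-check of the rewritten helper (an editor's illustration, not part of the diff): data is now indexed as a 2-D tensor, and out-of-bounds corners contribute zero unless snap_border first snaps near-border coordinates onto the grid.

    import torch

    data = torch.tensor([[0., 1.],
                         [2., 3.]])
    bilinear_interpolate(data, 0.5, 0.5)                     # 0.25 * (0 + 1 + 2 + 3) = 1.5
    bilinear_interpolate(data, -0.5, 0.5)                    # row -1 is dropped: 0.5 * 0.5 * (0 + 1) = 0.25
    bilinear_interpolate(data, -0.5, 0.5, snap_border=True)  # y snaps to 0 first: 0.5 * (0 + 1) = 0.5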


@@ -208,7 +224,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
     def get_script_fn(self, rois, pool_size):
         @torch.jit.script
         def script_fn(input, rois, pool_size):
-            # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+            # type: (Tensor, Tensor, int) -> Tensor
             return ops.roi_align(input, rois, pool_size, 1.0)[0]
         return lambda x: script_fn(x, rois, pool_size)

@@ -242,12 +258,7 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1):
                             y = start_h + (iy + 0.5) * bin_h / grid_h
                             for ix in range(0, grid_w):
                                 x = start_w + (ix + 0.5) * bin_w / grid_w
-                                val += bilinear_interpolate(
-                                    in_data[batch_idx, channel, :, :].flatten(),
-                                    in_data.size(-2),
-                                    in_data.size(-1),
-                                    y, x
-                                )
+                                val += bilinear_interpolate(in_data[batch_idx, channel, :, :], y, x, snap_border=True)
                         val /= grid_h * grid_w

                         out_data[r, channel, i, j] = val
Expand All @@ -262,7 +273,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
     def get_script_fn(self, rois, pool_size):
         @torch.jit.script
         def script_fn(input, rois, pool_size):
-            # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+            # type: (Tensor, Tensor, int) -> Tensor
             return ops.ps_roi_align(input, rois, pool_size, 1.0)[0]
         return lambda x: script_fn(x, rois, pool_size)

@@ -298,12 +309,7 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, device, spatial_scale=1,
                             y = start_h + (iy + 0.5) * bin_h / grid_h
                             for ix in range(0, grid_w):
                                 x = start_w + (ix + 0.5) * bin_w / grid_w
-                                val += bilinear_interpolate(
-                                    in_data[batch_idx, c_in, :, :].flatten(),
-                                    in_data.size(-2),
-                                    in_data.size(-1),
-                                    y, x
-                                )
+                                val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True)
                         val /= grid_h * grid_w

                         out_data[r, c_out, i, j] = val
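
Note the asymmetry this sets up: the RoI reference implementations above sample with snap_border=True, matching how the RoI kernels treat coordinates just outside the image, while the deformable-convolution reference below calls bilinear_interpolate with the default snap_border=False, so out-of-bounds samples are zero-padded. Hand-checked on a toy input (run inside this test module, which already imports torch):

    data = torch.ones(2, 2)
    bilinear_interpolate(data, 1.5, 0.5)                    # 0.5, blends with implicit zeros
    bilinear_interpolate(data, 1.5, 0.5, snap_border=True)  # 1.0, clamped to the last row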
@@ -376,5 +382,120 @@ def test_new_empty_tensor(self):
         assert out.dtype == input.dtype


+class DeformConvTester(OpTester, unittest.TestCase):
+    def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
+        stride_h, stride_w = _pair(stride)
+        pad_h, pad_w = _pair(padding)
+        dil_h, dil_w = _pair(dilation)
+        weight_h, weight_w = weight.shape[-2:]
+
+        n_batches, n_in_channels, in_h, in_w = x.shape
+        n_out_channels = weight.shape[0]
+
+        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
+        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
+
+        n_offset_grps = offset.shape[1] // (2 * weight_h * weight_w)
+        in_c_per_offset_grp = n_in_channels // n_offset_grps
+
+        n_weight_grps = n_in_channels // weight.shape[1]
+        in_c_per_weight_grp = weight.shape[1]
+        out_c_per_weight_grp = n_out_channels // n_weight_grps
+
+        out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
+        for b in range(n_batches):
+            for c_out in range(n_out_channels):
+                for i in range(out_h):
+                    for j in range(out_w):
+                        for di in range(weight_h):
+                            for dj in range(weight_w):
+                                for c in range(in_c_per_weight_grp):
+                                    weight_grp = c_out // out_c_per_weight_grp
+                                    c_in = weight_grp * in_c_per_weight_grp + c
+
+                                    offset_grp = c_in // in_c_per_offset_grp
+                                    offset_idx = 2 * (offset_grp * (weight_h * weight_w) + di * weight_w + dj)
+
+                                    pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
+                                    pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]
+
+                                    out[b, c_out, i, j] += (weight[c_out, c, di, dj] *
+                                                            bilinear_interpolate(x[b, c_in, :, :], pi, pj))
+        out += bias.view(1, n_out_channels, 1, 1)
+        return out
+
+    def get_fn_args(self, device, contiguous):
+        batch_sz = 1
+        n_in_channels = 6
+        n_out_channels = 2
+        n_weight_grps = 2
+        n_offset_grps = 3
+
+        stride = (2, 1)
+        pad = (1, 0)
+        dilation = (2, 1)
+
+        stride_h, stride_w = stride
+        pad_h, pad_w = pad
+        dil_h, dil_w = dilation
+        weight_h, weight_w = (3, 2)
+        in_h, in_w = (5, 4)
+
+        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
+        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
+
+        x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=self.dtype, requires_grad=True)
+
+        offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w,
+                             device=device, dtype=self.dtype, requires_grad=True)
+
+        weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
+                             device=device, dtype=self.dtype, requires_grad=True)
+
+        bias = torch.randn(n_out_channels, device=device, dtype=self.dtype, requires_grad=True)
+
+        if not contiguous:
+            x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
+            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
+            weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
+
+        return x, weight, offset, bias, stride, pad, dilation
+
+    def _test_forward(self, device, contiguous):
+        x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous)
+        in_channels = 6
+        out_channels = 2
+        kernel_size = (3, 2)
+        groups = 2
+        offset_groups = 3
+
+        layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
+                                 dilation=dilation, groups=groups, offset_groups=offset_groups).to(device=x.device,
+                                                                                                   dtype=x.dtype)
+        res = layer(x, offset)
+
+        weight = layer.weight.data
+        bias = layer.bias.data
+        expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation)
+
+        self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(res, expected))
+
+    def _test_backward(self, device, contiguous):
+        x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous)
+
+        def func(x_, offset_, weight_, bias_):
+            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation)
+
+        gradcheck(func, (x, offset, weight, bias), nondet_tol=1e-5)
+
+        @torch.jit.script
+        def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
+            # type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
+            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_)
+
+        gradcheck(lambda z, off, wei, bi: script_func(z, off, wei, bi, stride, padding, dilation),
+                  (x, offset, weight, bias), nondet_tol=1e-5)
+
+
 if __name__ == '__main__':
     unittest.main()
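
To close the loop on the arithmetic in expected_fn and get_fn_args: with in_h, in_w = (5, 4), kernel (3, 2), stride (2, 1), padding (1, 0), and dilation (2, 1), the standard dilated-convolution formula gives out_h = (5 + 2*1 - (2*(3-1) + 1)) // 2 + 1 = 2 and out_w = (4 + 0 - (1*(2-1) + 1)) // 1 + 1 = 3. A minimal sketch of the functional entry point under those shapes (mirroring _test_backward above; grouping is inferred from the weight and offset shapes rather than passed explicitly):

    import torch
    from torchvision import ops

    x = torch.rand(1, 6, 5, 4, dtype=torch.float64)
    offset = torch.randn(1, 36, 2, 3, dtype=torch.float64)  # 3 offset groups * 2 * (3 * 2) kernel taps
    weight = torch.randn(2, 3, 3, 2, dtype=torch.float64)   # 2 weight groups: 6 // 2 = 3 in-channels per group
    bias = torch.randn(2, dtype=torch.float64)

    out = ops.deform_conv2d(x, offset, weight, bias, stride=(2, 1), padding=(1, 0), dilation=(2, 1))
    assert out.shape == (1, 2, 2, 3)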