Add modulation input for DeformConv2D
Licht-T committed Nov 8, 2020
1 parent 052edce commit 2285693
Showing 8 changed files with 607 additions and 145 deletions.
test/test_ops.py (73 changes: 56 additions & 17 deletions)
@@ -458,7 +458,7 @@ def test_new_empty_tensor(self):


class DeformConvTester(OpTester, unittest.TestCase):
- def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
+ def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
stride_h, stride_w = _pair(stride)
pad_h, pad_w = _pair(padding)
dil_h, dil_w = _pair(dilation)
@@ -489,12 +489,17 @@ def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
c_in = weight_grp * in_c_per_weight_grp + c

offset_grp = c_in // in_c_per_offset_grp
- offset_idx = 2 * (offset_grp * (weight_h * weight_w) + di * weight_w + dj)
+ mask_idx = offset_grp * (weight_h * weight_w) + di * weight_w + dj
+ offset_idx = 2 * mask_idx

pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]

- out[b, c_out, i, j] += (weight[c_out, c, di, dj] *
+ mask_value = 1.0
+ if mask is not None:
+     mask_value = mask[b, mask_idx, i, j]

+ out[b, c_out, i, j] += (mask_value * weight[c_out, c, di, dj] *
bilinear_interpolate(x[b, c_in, :, :], pi, pj))
out += bias.view(1, n_out_channels, 1, 1)
return out
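
Note: the reference implementation above is the modulated deformable convolution of Deformable ConvNets v2. Each bilinearly sampled input value is scaled by a per-kernel-position, per-output-location mask value before accumulation, and mask=None falls back to the unmodulated v1 behavior with an implicit scale of 1.0. A minimal sketch of the new call; the shapes are illustrative assumptions, not values from this diff:

    import torch
    import torchvision.ops as ops

    # assumed toy shapes: batch 1, 4 in / 6 out channels, 8x8 input, 3x3 kernel, 1 offset group
    x = torch.randn(1, 4, 8, 8)
    weight = torch.randn(6, 4, 3, 3)
    offset = torch.randn(1, 2 * 3 * 3, 8, 8)           # (N, 2*kH*kW, outH, outW), (y, x) pairs
    mask = torch.sigmoid(torch.randn(1, 3 * 3, 8, 8))  # (N, kH*kW, outH, outW), squashed to [0, 1]

    out_v2 = ops.deform_conv2d(x, offset, weight, padding=1, mask=mask)  # modulated
    out_v1 = ops.deform_conv2d(x, offset, weight, padding=1)             # unmodulated

The sigmoid on the mask mirrors the DCNv2 paper's convention of keeping modulation scalars in [0, 1]; as the tests here use torch.randn masks, the op itself accepts any real-valued mask.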
@@ -523,6 +528,9 @@ def get_fn_args(self, device, contiguous, batch_sz, dtype):
offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w,
device=device, dtype=dtype, requires_grad=True)

+ mask = torch.randn(batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w,
+                    device=device, dtype=dtype, requires_grad=True)

weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
device=device, dtype=dtype, requires_grad=True)

@@ -531,31 +539,39 @@ def get_fn_args(self, device, contiguous, batch_sz, dtype):
if not contiguous:
x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
+ mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)

- return x, weight, offset, bias, stride, pad, dilation
+ return x, weight, offset, mask, bias, stride, pad, dilation
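
The double permute-contiguous-permute pattern above is a standard trick to produce a tensor with identical logical content but a non-contiguous memory layout, so that the operator's non-contiguous code paths get exercised. A quick self-contained check of the trick, with an arbitrary assumed shape:

    import torch

    t = torch.randn(2, 3, 4, 5)
    nc = t.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
    assert torch.equal(t, nc)      # same values...
    assert not nc.is_contiguous()  # ...different memory layout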

def _test_forward(self, device, contiguous, dtype=None):
dtype = self.dtype if dtype is None else dtype
for batch_sz in [0, 33]:
self._test_forward_with_batchsize(device, contiguous, batch_sz, dtype)

def _test_forward_with_batchsize(self, device, contiguous, batch_sz, dtype):
- x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
+ x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
in_channels = 6
out_channels = 2
kernel_size = (3, 2)
groups = 2
tol = 1e-3 if dtype is torch.half else 1e-5

layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups).to(device=x.device, dtype=dtype)
- res = layer(x, offset)
+ res = layer(x, offset, mask)

weight = layer.weight.data
bias = layer.bias.data
- expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation)
+ expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)

self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
'\nres:\n{}\nexpected:\n{}'.format(res, expected))

+ # no modulation test
+ res = layer(x, offset)
+ expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)

+ tol = 1e-3 if dtype is torch.half else 1e-5
+ self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
+                 '\nres:\n{}\nexpected:\n{}'.format(res, expected))
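
The "no modulation test" above implies that mask=None should be numerically equivalent to an all-ones mask, since expected_fn substitutes a scale of 1.0 when mask is None. A small equivalence sketch under that assumption, again with illustrative shapes:

    import torch
    import torchvision.ops as ops

    x = torch.randn(1, 4, 8, 8)
    weight = torch.randn(6, 4, 3, 3)
    offset = torch.randn(1, 2 * 3 * 3, 8, 8)
    ones = torch.ones(1, 3 * 3, 8, 8)  # scales every sample by 1.0

    out_none = ops.deform_conv2d(x, offset, weight, padding=1, mask=None)
    out_ones = ops.deform_conv2d(x, offset, weight, padding=1, mask=ones)
    assert torch.allclose(out_none, out_ones, atol=1e-5)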

@@ -564,24 +580,45 @@ def _test_forward_with_batchsize(self, device, contiguous, batch_sz, dtype):
wrong_offset = torch.rand_like(offset[:, :2])
res = layer(x, wrong_offset)

+ with self.assertRaises(RuntimeError):
+     wrong_mask = torch.rand_like(mask[:, :2])
+     res = layer(x, offset, wrong_mask)

def _test_backward(self, device, contiguous):
for batch_sz in [0, 33]:
self._test_backward_with_batchsize(device, contiguous, batch_sz)

def _test_backward_with_batchsize(self, device, contiguous, batch_sz):
- x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, self.dtype)
+ x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, self.dtype)

+ def func(x_, offset_, mask_, weight_, bias_):
+     return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride,
+                              padding=padding, dilation=dilation, mask=mask_)

- def func(x_, offset_, weight_, bias_):
-     return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation)
+ gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5)

+ def func_no_mask(x_, offset_, weight_, bias_):
+     return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride,
+                              padding=padding, dilation=dilation, mask=None)

+ gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5)

+ @torch.jit.script
+ def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
+     # type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
+     return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_,
+                              padding=pad_, dilation=dilation_, mask=mask_)

- gradcheck(func, (x, offset, weight, bias), nondet_tol=1e-5)
+ gradcheck(lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation),
+           (x, offset, mask, weight, bias), nondet_tol=1e-5)

@torch.jit.script
- def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
-     # type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
-     return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_)
+ def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
+     # type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
+     return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_,
+                              padding=pad_, dilation=dilation_, mask=None)

- gradcheck(lambda z, off, wei, bi: script_func(z, off, wei, bi, stride, padding, dilation),
+ gradcheck(lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation),
(x, offset, weight, bias), nondet_tol=1e-5)

# Test from https://github.com/pytorch/vision/issues/2598
@@ -593,17 +630,19 @@ def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
init_weight = torch.randn(9, 9, 3, 3, requires_grad=True)
img = torch.randn(8, 9, 1000, 110)
offset = torch.rand(8, 2 * 3 * 3, 1000, 110)
+ mask = torch.rand(8, 3 * 3, 1000, 110)

if not contiguous:
img = img.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
+ mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
weight = init_weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
else:
weight = init_weight

for d in ["cpu", "cuda"]:

- out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1)
+ out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
out.mean().backward()
if true_cpu_grads is None:
true_cpu_grads = init_weight.grad