Skip to content

Commit

Permalink
Fix Bug in Bilinear Interpolation and Add Deform Conv to PT FrontEnd (#7397)

* Fix Bug in Bilinear Interpolation

* Add NHWC Tests

* clean

* Fix Bug and Add Deformable Conv PyTorch for completeness

* Add Tensor Utils

* Remove stuff

* Include vector

* PR Comments

* Empty Commit for CI

Co-authored-by: Ubuntu <ubuntu@ip-172-31-42-251.us-east-2.compute.internal>
  • Loading branch information
codeislife99 and Ubuntu authored Feb 5, 2021
1 parent c118b08 commit 38c9eb1
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 88 deletions.
95 changes: 55 additions & 40 deletions include/tvm/topi/detail/tensor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/te/operation.h>

#include <vector>
namespace tvm {
namespace topi {
namespace detail {
Expand Down Expand Up @@ -64,29 +65,36 @@ inline bool is_empty_shape(const Array<PrimExpr>& x) {
*/
inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>& indices,
                                     const PrimExpr max_y, const PrimExpr max_x) {
  // indices is laid out as (batch, channel, y, x); y and x are the fractional sample position.
  auto batch_id = indices[0];
  auto channel_id = indices[1];
  auto in_y = indices[2];
  auto in_x = indices[3];

  // Integer coordinates of the four pixels surrounding the sample position.
  auto y_low = tvm::cast(DataType::Int(32), tvm::floor(in_y));
  auto y_high = y_low + 1;

  auto x_low = tvm::cast(DataType::Int(32), tvm::floor(in_x));
  auto x_high = x_low + 1;

  // Interpolation weights: *_h weighs the high neighbour, *_l the low neighbour.
  auto wy_h = in_y - y_low;
  auto wx_h = in_x - x_low;
  auto wy_l = 1 - wy_h;
  auto wx_l = 1 - wx_h;

  // Weighted value of one neighbour. A neighbour outside [0, max_y] x [0, max_x] contributes 0
  // instead of being clamped onto the border, so no out-of-range element is selected.
  auto corner = [&](const PrimExpr& wx, const PrimExpr& xp, const PrimExpr& wy,
                    const PrimExpr& yp) {
    return tvm::if_then_else(0 <= yp && yp <= max_y && 0 <= xp && xp <= max_x,
                             wx * wy * input(batch_id, channel_id, yp, xp), 0);
  };

  // Accumulate in a fixed order (x outer, y inner) so the generated expression is deterministic.
  PrimExpr val = 0;
  val += corner(wx_l, x_low, wy_l, y_low);
  val += corner(wx_l, x_low, wy_h, y_high);
  val += corner(wx_h, x_high, wy_l, y_low);
  val += corner(wx_h, x_high, wy_h, y_high);
  return val;
}

/*!
Expand All @@ -101,29 +109,36 @@ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>&
*/
inline PrimExpr bilinear_sample_nhwc(const Tensor& input, const Array<PrimExpr>& indices,
                                     const PrimExpr max_y, const PrimExpr max_x) {
  // indices is laid out as (batch, y, x, channel); y and x are the fractional sample position.
  auto batch_id = indices[0];
  auto channel_id = indices[3];
  auto in_y = indices[1];
  auto in_x = indices[2];

  // Integer coordinates of the four pixels surrounding the sample position.
  auto y_low = tvm::cast(DataType::Int(32), tvm::floor(in_y));
  auto y_high = y_low + 1;

  auto x_low = tvm::cast(DataType::Int(32), tvm::floor(in_x));
  auto x_high = x_low + 1;

  // Interpolation weights: *_h weighs the high neighbour, *_l the low neighbour.
  auto wy_h = in_y - y_low;
  auto wx_h = in_x - x_low;
  auto wy_l = 1 - wy_h;
  auto wx_l = 1 - wx_h;

  // Weighted value of one neighbour. A neighbour outside [0, max_y] x [0, max_x] contributes 0
  // instead of being clamped onto the border, so no out-of-range element is selected.
  auto corner = [&](const PrimExpr& wx, const PrimExpr& xp, const PrimExpr& wy,
                    const PrimExpr& yp) {
    return tvm::if_then_else(0 <= yp && yp <= max_y && 0 <= xp && xp <= max_x,
                             wx * wy * input(batch_id, yp, xp, channel_id), 0);
  };

  // Accumulate in a fixed order (x outer, y inner) so the generated expression is deterministic.
  PrimExpr val = 0;
  val += corner(wx_l, x_low, wy_l, y_low);
  val += corner(wx_l, x_low, wy_h, y_high);
  val += corner(wx_h, x_high, wy_l, y_low);
  val += corner(wx_h, x_high, wy_h, y_high);
  return val;
}

} // namespace detail
Expand Down
27 changes: 27 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1928,6 +1928,32 @@ def roi_align(self, inputs, input_types):

return _op.vision.roi_align(data, boxes, output_size, spatial_scale, sample_ratio)

def deform_conv2d(self, inputs, input_types):
    """Convert ``torchvision::deform_conv2d`` into ``nn.deformable_conv2d``.

    ``inputs`` is read positionally: 0 data, 1 weight, 2 offset, 3 bias,
    4-5 strides, 6-7 padding, 8-9 dilation, 10 groups, 11 deformable groups.
    The output channel count and kernel size are taken from the inferred
    weight shape.
    """
    data = inputs[0]
    weight = inputs[1]
    offset = inputs[2]
    # NOTE(review): slot 3 is the bias in the torchvision::deform_conv2d schema
    # this index layout follows; confirm against the torchvision version in use.
    bias = inputs[3]
    strides = (inputs[4], inputs[5])
    padding = (inputs[6], inputs[7])
    dilation = (inputs[8], inputs[9])
    groups = inputs[10]
    deformable_groups = inputs[11]
    weight_shape = self.infer_shape(weight)
    output_channels = weight_shape[0]
    kernel_size = (weight_shape[2], weight_shape[3])

    # Note the argument order: the relay op takes offset before weight.
    conv_out = _op.nn.deformable_conv2d(
        data,
        offset,
        weight,
        strides,
        padding,
        dilation,
        deformable_groups,
        groups,
        output_channels,
        kernel_size,
    )

    # Previously the bias was dropped, which is only correct when it is all zeros.
    if bias is None:
        return conv_out
    return _op.nn.bias_add(conv_out, bias)

def unbind(self, inputs, input_types):
data = inputs[0]
dim = int(inputs[1])
Expand Down Expand Up @@ -2292,6 +2318,7 @@ def create_convert_map(self):
"torchvision::nms": self.nms,
"aten::logsumexp": self.logsumexp,
"torchvision::roi_align": self.roi_align,
"torchvision::deform_conv2d": self.deform_conv2d,
"aten::unbind": self.unbind,
"aten::__and__": self.logical_and,
"aten::logical_and": self.logical_and,
Expand Down
26 changes: 17 additions & 9 deletions python/tvm/topi/testing/deformable_conv2d_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Deformable convolution in python"""
import itertools
import math
import numpy as np
from tvm.topi.nn.utils import get_pad_tuple

Expand Down Expand Up @@ -80,15 +81,22 @@ def deformable_conv2d_nchw_python(
dilation_h, dilation_w = dilation

def _bilinear(n, c, h, w):
    """Bilinearly sample ``a_np[n, c]`` at the fractional position ``(h, w)``.

    ``a_np``, ``in_height`` and ``in_width`` are read from the enclosing scope.
    Neighbouring pixels that fall outside the image contribute zero instead of
    being clamped onto the border.
    """
    # math.floor (not int()) so negative coordinates round down, not toward zero.
    y_low = int(math.floor(h))
    x_low = int(math.floor(w))
    y_high = y_low + 1
    x_high = x_low + 1

    # *_h weighs the high neighbour, *_l the low neighbour.
    wy_h = h - y_low
    wx_h = w - x_low
    wy_l = 1 - wy_h
    wx_l = 1 - wx_h

    val = 0
    for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
        for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
            if 0 <= yp < in_height and 0 <= xp < in_width:
                val += wx * wy * a_np[n, c, yp, xp]
    return val

a_deform = np.zeros((batch, in_channel, out_height, out_width, kernel_h, kernel_w), dtype=dtype)
for n, h, w in itertools.product(range(batch), range(out_height), range(out_width)):
Expand Down
34 changes: 19 additions & 15 deletions python/tvm/topi/testing/roi_align_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,29 @@ def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_rati
else:
pooled_size_h, pooled_size_w = pooled_size

def _bilinear(n, c, y, x):
    """Bilinearly sample ``a_np[n, c]`` at the fractional position ``(y, x)``.

    ``a_np``, ``height`` and ``width`` are read from the enclosing scope.
    Positions more than one pixel outside the feature map return 0; all other
    positions are clamped into ``[0, height - 1] x [0, width - 1]`` first.
    """
    if y < -1 or y > height or x < -1 or x > width:
        return 0

    y = min(max(y, 0), height - 1)
    x = min(max(x, 0), width - 1)

    y_low = int(math.floor(y))
    x_low = int(math.floor(x))
    y_high = y_low + 1
    x_high = x_low + 1

    # *_h weighs the high neighbour, *_l the low neighbour.
    wy_h = y - y_low
    wx_h = x - x_low
    wy_l = 1 - wy_h
    wx_l = 1 - wx_h

    val = 0
    for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
        for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
            # After clamping, the high neighbour can still be one past the edge;
            # skip it rather than index out of range.
            if 0 <= yp < height and 0 <= xp < width:
                val += wx * wy * a_np[n, c, yp, xp]
    return val

for i in range(num_roi):
roi = rois_np[i]
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/topi/vision/rcnn/roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):

def _bilinear(i, c, y, x):
    # Positions more than one pixel outside the feature map yield 0.
    outside = tvm.tir.any(y < -1.0, x < -1.0, y > height, x > width)
    # Clamp the sample position on both sides before interpolating.
    y = tvm.te.min(tvm.te.max(y, 0.0), height - 1)
    x = tvm.te.min(tvm.te.max(x, 0.0), width - 1)
    val = bilinear_sample_nchw(data, (i, c, y, x), height - 1, width - 1)
    return tvm.tir.if_then_else(outside, 0.0, val)

Expand Down
88 changes: 83 additions & 5 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at

assert_shapes_match(baseline_output, compiled_output)
tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol)

del model_name
del baseline_model
torch.cuda.empty_cache()
Expand Down Expand Up @@ -924,6 +923,85 @@ def test_forward_conv_transpose():
verify_model(torch.nn.ConvTranspose1d(3, 12, 3, bias=False), input_data=conv1d_input_data)


def test_forward_deform_conv():
    """Check torchvision.ops.deform_conv2d conversion for two batch/offset-group setups."""
    torch.set_grad_enabled(False)

    def test_run(
        batch_size,
        in_channels,
        out_channels,
        in_height,
        in_width,
        out_height,
        out_width,
        offset_groups,
        kh,
        kw,
        groups,
    ):
        # Random tensors are drawn in the order input, offset, weight.
        input_data = torch.rand([batch_size, in_channels, in_height, in_width])
        offset_data = torch.rand([batch_size, 2 * offset_groups * kh * kw, out_height, out_width])
        weight_data = torch.rand([out_channels, in_channels // groups, kh, kw])

        class DeformConv2D(Module):
            def forward(self, *args):
                return torchvision.ops.deform_conv2d(args[0], args[1], args[2])

        verify_model(
            DeformConv2D().float().eval(),
            input_data=[input_data, offset_data, weight_data],
            rtol=1e-4,
            atol=1e-4,
        )

    # Settings common to both runs; only batch size and offset groups vary.
    shared = dict(
        in_channels=4,
        out_channels=6,
        in_height=10,
        in_width=10,
        out_height=8,
        out_width=8,
        kh=3,
        kw=3,
        groups=1,
    )
    for batch_size, offset_groups in ((4, 2), (5, 1)):
        test_run(batch_size=batch_size, offset_groups=offset_groups, **shared)


@tvm.testing.uses_gpu
def test_forward_threshold():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -1700,7 +1778,7 @@ def test_forward_roi_align():
"""ROI align"""
torch.set_grad_enabled(False)

class ROIAlgin(Module):
class ROIAlign(Module):
def __init__(self, output_sizes, spatial_scale=1.0, sampling_ratio=-1):
super().__init__()
self.spatial_scale = spatial_scale
Expand All @@ -1721,9 +1799,9 @@ def forward(self, *args):
in_batch = torch.zeros((35, 1), dtype=torch.float)
in_boxes = torch.cat([in_batch, in_boxes], dim=1)

verify_model(ROIAlgin(7), [in_data, in_boxes])
verify_model(ROIAlgin((10, 10), 0.7, 5), [in_data, in_boxes])
verify_model(ROIAlgin(15, 0.9, 3), [in_data, in_boxes])
verify_model(ROIAlign(7), [in_data, in_boxes])
verify_model(ROIAlign((10, 10), 0.7, 5), [in_data, in_boxes])
verify_model(ROIAlign(15, 0.9, 3), [in_data, in_boxes])


@tvm.testing.uses_gpu
Expand Down
Loading

0 comments on commit 38c9eb1

Please sign in to comment.