From 38c9eb1a7876eb1ff837424e72d5a4870bdf7e1c Mon Sep 17 00:00:00 2001 From: Ritwik Das Date: Thu, 4 Feb 2021 20:49:07 -0800 Subject: [PATCH] Fix Bug in Bilinear Interpolation and Add Deform Conv to PT FrontEnd (#7397) * Fix Bug in Bilinear Interpolation * Add NHWC Tests * clean * Fix Bug and Add Deformable Conv PyTorch for completeness * Add Tensor Utils * Remove stuff * Include vector * PR Comments * Empty Commit for CI Co-authored-by: Ubuntu --- include/tvm/topi/detail/tensor_utils.h | 95 +++++++++++-------- python/tvm/relay/frontend/pytorch.py | 27 ++++++ .../topi/testing/deformable_conv2d_python.py | 26 +++-- python/tvm/topi/testing/roi_align_python.py | 34 ++++--- python/tvm/topi/vision/rcnn/roi_align.py | 4 +- tests/python/frontend/pytorch/test_forward.py | 88 ++++++++++++++++- tests/python/relay/test_op_level5.py | 71 ++++++++++---- 7 files changed, 257 insertions(+), 88 deletions(-) diff --git a/include/tvm/topi/detail/tensor_utils.h b/include/tvm/topi/detail/tensor_utils.h index 65a760b1397c..397c70c9451e 100644 --- a/include/tvm/topi/detail/tensor_utils.h +++ b/include/tvm/topi/detail/tensor_utils.h @@ -26,6 +26,7 @@ #include +#include namespace tvm { namespace topi { namespace detail { @@ -64,29 +65,36 @@ inline bool is_empty_shape(const Array& x) { */ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array& indices, const PrimExpr max_y, const PrimExpr max_x) { + auto batch_id = indices[0]; + auto channel_id = indices[1]; auto in_y = indices[2]; - auto yf = tvm::floor(in_y); - auto yc = tvm::cast(DataType::Int(32), tvm::ceil(in_y)); - - auto y0 = tvm::cast(DataType::Int(32), tvm::floor(in_y)); - auto y1 = tvm::if_then_else((yc > max_y), max_y, yc); - auto y_lerp = in_y - yf; - auto in_x = indices[3]; - auto xf = tvm::floor(in_x); - auto xc = tvm::cast(DataType::Int(32), tvm::ceil(in_x)); - - auto x0 = tvm::cast(DataType::Int(32), tvm::floor(in_x)); - auto x1 = tvm::if_then_else((xc > max_x), max_x, xc); - auto x_lerp = in_x - xf; - auto A = input(indices[0], indices[1], y0, x0); - auto B = input(indices[0], indices[1], y0, x1); - auto C = input(indices[0], indices[1], y1, x0); - auto D = input(indices[0], indices[1], y1, x1); - - return A * (1 - x_lerp) * (1 - y_lerp) + B * x_lerp * (1 - y_lerp) + C * (1 - x_lerp) * y_lerp + - D * x_lerp * y_lerp; + auto y_low = tvm::cast(DataType::Int(32), tvm::floor(in_y)); + auto y_high = y_low + 1; + + auto x_low = tvm::cast(DataType::Int(32), tvm::floor(in_x)); + auto x_high = x_low + 1; + + auto wy_h = in_y - y_low; + auto wx_h = in_x - x_low; + auto wy_l = 1 - wy_h; + auto wx_l = 1 - wx_h; + + PrimExpr val = 0; + std::vector> wx_xp{{wx_l, x_low}, {wx_h, x_high}}; + std::vector> wy_yp{{wy_l, y_low}, {wy_h, y_high}}; + for (auto wx_xp_ele : wx_xp) { + for (auto wy_yp_ele : wy_yp) { + auto wx = wx_xp_ele[0]; + auto xp = wx_xp_ele[1]; + auto wy = wy_yp_ele[0]; + auto yp = wy_yp_ele[1]; + val += tvm::if_then_else(0 <= yp && yp <= max_y && 0 <= xp && xp <= max_x, + wx * wy * input(batch_id, channel_id, yp, xp), 0); + } + } + return val; } /*! @@ -101,29 +109,36 @@ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array& */ inline PrimExpr bilinear_sample_nhwc(const Tensor& input, const Array& indices, const PrimExpr max_y, const PrimExpr max_x) { + auto batch_id = indices[0]; + auto channel_id = indices[3]; auto in_y = indices[1]; - auto yf = tvm::floor(in_y); - auto yc = tvm::cast(DataType::Int(32), tvm::ceil(in_y)); - - auto y0 = tvm::cast(DataType::Int(32), tvm::floor(in_y)); - auto y1 = tvm::if_then_else((yc > max_y), max_y, yc); - auto y_lerp = in_y - yf; - auto in_x = indices[2]; - auto xf = tvm::floor(in_x); - auto xc = tvm::cast(DataType::Int(32), tvm::ceil(in_x)); - - auto x0 = tvm::cast(DataType::Int(32), tvm::floor(in_x)); - auto x1 = tvm::if_then_else((xc > max_x), max_x, xc); - auto x_lerp = in_x - xf; - auto A = input(indices[0], y0, x0, indices[3]); - auto B = input(indices[0], y0, x1, indices[3]); - auto C = input(indices[0], y1, x0, indices[3]); - auto D = input(indices[0], y1, x1, indices[3]); - - return A * (1 - x_lerp) * (1 - y_lerp) + B * x_lerp * (1 - y_lerp) + C * (1 - x_lerp) * y_lerp + - D * x_lerp * y_lerp; + auto y_low = tvm::cast(DataType::Int(32), tvm::floor(in_y)); + auto y_high = y_low + 1; + + auto x_low = tvm::cast(DataType::Int(32), tvm::floor(in_x)); + auto x_high = x_low + 1; + + auto wy_h = in_y - y_low; + auto wx_h = in_x - x_low; + auto wy_l = 1 - wy_h; + auto wx_l = 1 - wx_h; + + PrimExpr val = 0; + std::vector> wx_xp{{wx_l, x_low}, {wx_h, x_high}}; + std::vector> wy_yp{{wy_l, y_low}, {wy_h, y_high}}; + for (auto wx_xp_ele : wx_xp) { + for (auto wy_yp_ele : wy_yp) { + auto wx = wx_xp_ele[0]; + auto xp = wx_xp_ele[1]; + auto wy = wy_yp_ele[0]; + auto yp = wy_yp_ele[1]; + val += tvm::if_then_else(0 <= yp && yp <= max_y && 0 <= xp && xp <= max_x, + wx * wy * input(batch_id, yp, xp, channel_id), 0); + } + } + return val; } } // namespace detail diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 68e68fdbeed2..246ed97b14e9 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1928,6 +1928,32 @@ def roi_align(self, inputs, input_types): return _op.vision.roi_align(data, boxes, output_size, spatial_scale, sample_ratio) + def deform_conv2d(self, inputs, input_types): + data = inputs[0] + weight = inputs[1] + offset = inputs[2] + strides = (inputs[4], inputs[5]) + padding = (inputs[6], inputs[7]) + dilation = (inputs[8], inputs[9]) + groups = inputs[10] + deformable_groups = inputs[11] + weight_shape = self.infer_shape(weight) + output_channels = weight_shape[0] + kernel_size = (weight_shape[2], weight_shape[3]) + + return _op.nn.deformable_conv2d( + data, + offset, + weight, + strides, + padding, + dilation, + deformable_groups, + groups, + output_channels, + kernel_size, + ) + def unbind(self, inputs, input_types): data = inputs[0] dim = int(inputs[1]) @@ -2292,6 +2318,7 @@ def create_convert_map(self): "torchvision::nms": self.nms, "aten::logsumexp": self.logsumexp, "torchvision::roi_align": self.roi_align, + "torchvision::deform_conv2d": self.deform_conv2d, "aten::unbind": self.unbind, "aten::__and__": self.logical_and, "aten::logical_and": self.logical_and, diff --git a/python/tvm/topi/testing/deformable_conv2d_python.py b/python/tvm/topi/testing/deformable_conv2d_python.py index 093084397ff1..758a70eb4cc1 100644 --- a/python/tvm/topi/testing/deformable_conv2d_python.py +++ b/python/tvm/topi/testing/deformable_conv2d_python.py @@ -17,6 +17,7 @@ # pylint: disable=invalid-name, too-many-locals, too-many-arguments """Deformable convolution in python""" import itertools +import math import numpy as np from tvm.topi.nn.utils import get_pad_tuple @@ -80,15 +81,22 @@ def deformable_conv2d_nchw_python( dilation_h, dilation_w = dilation def _bilinear(n, c, h, w): - low_h, low_w = int(h), int(w) - high_h = min(low_h + 1, in_height - 1) - high_w = min(low_w + 1, in_width - 1) - y_lerp = h - low_h - x_lerp = w - low_w - - bottom = (1 - x_lerp) * a_np[n, c, low_h, low_w] + x_lerp * a_np[n, c, low_h, high_w] - top = (1 - x_lerp) * a_np[n, c, high_h, low_w] + x_lerp * a_np[n, c, high_h, high_w] - return (1 - y_lerp) * bottom + y_lerp * top + y_low = int(math.floor(h)) + x_low = int(math.floor(w)) + y_high = y_low + 1 + x_high = x_low + 1 + + wy_h = h - y_low + wx_h = w - x_low + wy_l = 1 - wy_h + wx_l = 1 - wx_h + + val = 0 + for wx, xp in zip((wx_l, wx_h), (x_low, x_high)): + for wy, yp in zip((wy_l, wy_h), (y_low, y_high)): + if 0 <= yp < in_height and 0 <= xp < in_width: + val += wx * wy * a_np[n, c, yp, xp] + return val a_deform = np.zeros((batch, in_channel, out_height, out_width, kernel_h, kernel_w), dtype=dtype) for n, h, w in itertools.product(range(batch), range(out_height), range(out_width)): diff --git a/python/tvm/topi/testing/roi_align_python.py b/python/tvm/topi/testing/roi_align_python.py index 5bb292c46fbb..abef25f0b994 100644 --- a/python/tvm/topi/testing/roi_align_python.py +++ b/python/tvm/topi/testing/roi_align_python.py @@ -31,25 +31,29 @@ def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_rati else: pooled_size_h, pooled_size_w = pooled_size - def _bilinear(b, c, y, x): + def _bilinear(n, c, y, x): if y < -1 or y > height or x < -1 or x > width: return 0 - y = max(y, 0.0) - x = max(x, 0.0) - y_low = int(y) - x_low = int(x) - y_high = min(y_low + 1, height - 1) - x_high = min(x_low + 1, width - 1) + y = min(max(y, 0), height - 1) + x = min(max(x, 0), width - 1) - ly = y - y_low - lx = x - x_low - return ( - (1 - ly) * (1 - lx) * a_np[b, c, y_low, x_low] - + (1 - ly) * lx * a_np[b, c, y_low, x_high] - + ly * (1 - lx) * a_np[b, c, y_high, x_low] - + ly * lx * a_np[b, c, y_high, x_high] - ) + y_low = int(math.floor(y)) + x_low = int(math.floor(x)) + y_high = y_low + 1 + x_high = x_low + 1 + + wy_h = y - y_low + wx_h = x - x_low + wy_l = 1 - wy_h + wx_l = 1 - wx_h + + val = 0 + for wx, xp in zip((wx_l, wx_h), (x_low, x_high)): + for wy, yp in zip((wy_l, wy_h), (y_low, y_high)): + if 0 <= yp < height and 0 <= xp < width: + val += wx * wy * a_np[n, c, yp, xp] + return val for i in range(num_roi): roi = rois_np[i] diff --git a/python/tvm/topi/vision/rcnn/roi_align.py b/python/tvm/topi/vision/rcnn/roi_align.py index a51ba33a6c45..30824770b7b2 100644 --- a/python/tvm/topi/vision/rcnn/roi_align.py +++ b/python/tvm/topi/vision/rcnn/roi_align.py @@ -60,8 +60,8 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1): def _bilinear(i, c, y, x): outside = tvm.tir.any(y < -1.0, x < -1.0, y > height, x > width) - y = tvm.te.max(y, 0.0) - x = tvm.te.max(x, 0.0) + y = tvm.te.min(tvm.te.max(y, 0.0), height - 1) + x = tvm.te.min(tvm.te.max(x, 0.0), width - 1) val = bilinear_sample_nchw(data, (i, c, y, x), height - 1, width - 1) return tvm.tir.if_then_else(outside, 0.0, val) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 6d9b559c6ba1..8d968e9760c9 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -216,7 +216,6 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at assert_shapes_match(baseline_output, compiled_output) tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol) - del model_name del baseline_model torch.cuda.empty_cache() @@ -924,6 +923,85 @@ def test_forward_conv_transpose(): verify_model(torch.nn.ConvTranspose1d(3, 12, 3, bias=False), input_data=conv1d_input_data) +def test_forward_deform_conv(): + torch.set_grad_enabled(False) + + def test_run( + batch_size, + in_channels, + out_channels, + in_height, + in_width, + out_height, + out_width, + offset_groups, + kh, + kw, + groups, + ): + input_shape = [batch_size, in_channels, in_height, in_width] + offset_shape = [batch_size, 2 * offset_groups * kh * kw, out_height, out_width] + weight_shape = [out_channels, in_channels // groups, kh, kw] + input_data = torch.rand(input_shape) + offset_data = torch.rand(offset_shape) + weight_data = torch.rand(weight_shape) + + class DeformConv2D(Module): + def forward(self, *args): + return torchvision.ops.deform_conv2d(args[0], args[1], args[2]) + + verify_model( + DeformConv2D().float().eval(), + input_data=[input_data, offset_data, weight_data], + rtol=1e-4, + atol=1e-4, + ) + + batch_size = 4 + in_channels, out_channels = 4, 6 + in_height, in_width = 10, 10 + out_height, out_width = 8, 8 + offset_groups = 2 + kh, kw = 3, 3 + groups = 1 + + test_run( + batch_size, + in_channels, + out_channels, + in_height, + in_width, + out_height, + out_width, + offset_groups, + kh, + kw, + groups, + ) + + batch_size = 5 + in_channels, out_channels = 4, 6 + in_height, in_width = 10, 10 + out_height, out_width = 8, 8 + offset_groups = 1 + kh, kw = 3, 3 + groups = 1 + + test_run( + batch_size, + in_channels, + out_channels, + in_height, + in_width, + out_height, + out_width, + offset_groups, + kh, + kw, + groups, + ) + + @tvm.testing.uses_gpu def test_forward_threshold(): torch.set_grad_enabled(False) @@ -1700,7 +1778,7 @@ def test_forward_roi_align(): """ROI align""" torch.set_grad_enabled(False) - class ROIAlgin(Module): + class ROIAlign(Module): def __init__(self, output_sizes, spatial_scale=1.0, sampling_ratio=-1): super().__init__() self.spatial_scale = spatial_scale @@ -1721,9 +1799,9 @@ def forward(self, *args): in_batch = torch.zeros((35, 1), dtype=torch.float) in_boxes = torch.cat([in_batch, in_boxes], dim=1) - verify_model(ROIAlgin(7), [in_data, in_boxes]) - verify_model(ROIAlgin((10, 10), 0.7, 5), [in_data, in_boxes]) - verify_model(ROIAlgin(15, 0.9, 3), [in_data, in_boxes]) + verify_model(ROIAlign(7), [in_data, in_boxes]) + verify_model(ROIAlign((10, 10), 0.7, 5), [in_data, in_boxes]) + verify_model(ROIAlign(15, 0.9, 3), [in_data, in_boxes]) @tvm.testing.uses_gpu diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index cdf3b240507b..6d7d401d706b 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -837,11 +837,31 @@ def test_infer_type(batch, in_channel, size, out_channel, deformable_groups, gro test_infer_type(1, 4, 16, 4, 4, 1, "NHWC") test_infer_type(2, 4, 16, 4, 1, 2, "NHWC") - def test_run(batch, in_channel, size, out_channel, deformable_groups, groups): + def test_run(batch, in_channel, size, out_channel, deformable_groups, groups, layout): kernel_size = (3, 3) - data_shape = (batch, in_channel, size, size) - offset_shape = (batch, 2 * kernel_size[0] * kernel_size[1] * deformable_groups, size, size) - kernel_shape = (out_channel, in_channel // groups, kernel_size[0], kernel_size[1]) + if layout == "NCHW": + kernel_layout = "OIHW" + data_shape = (batch, in_channel, size, size) + kernel_shape = (out_channel, in_channel // groups, kernel_size[0], kernel_size[1]) + out_shape = (batch, out_channel, size, size) + offset_shape = ( + batch, + 2 * kernel_size[0] * kernel_size[1] * deformable_groups, + out_shape[2], + out_shape[3], + ) + else: + kernel_layout = "HWIO" + data_shape = (batch, size, size, in_channel) + kernel_shape = (kernel_size[0], kernel_size[1], in_channel // groups, out_channel) + out_shape = (batch, size, size, out_channel) + offset_shape = ( + batch, + out_shape[1], + out_shape[2], + 2 * kernel_size[0] * kernel_size[1] * deformable_groups, + ) + dtype = "float32" data = relay.var("data", shape=data_shape, dtype=dtype) offset = relay.var("offset") @@ -853,6 +873,8 @@ def test_run(batch, in_channel, size, out_channel, deformable_groups, groups): strides=(1, 1), padding=(1, 1), dilation=(1, 1), + data_layout=layout, + kernel_layout=kernel_layout, kernel_size=kernel_size, deformable_groups=deformable_groups, groups=groups, @@ -862,25 +884,40 @@ def test_run(batch, in_channel, size, out_channel, deformable_groups, groups): data = np.random.uniform(size=data_shape).astype(dtype) offset = np.random.uniform(size=offset_shape).astype(dtype) kernel = np.random.uniform(size=kernel_shape).astype(dtype) - ref_res = tvm.topi.testing.deformable_conv2d_nchw_python( - data, - offset, - kernel, - stride=(1, 1), - padding=(1, 1), - dilation=(1, 1), - deformable_groups=deformable_groups, - groups=groups, - ) - + if layout == "NCHW": + ref_res = tvm.topi.testing.deformable_conv2d_nchw_python( + data, + offset, + kernel, + stride=(1, 1), + padding=(1, 1), + dilation=(1, 1), + deformable_groups=deformable_groups, + groups=groups, + ) + else: + ref_res = tvm.topi.testing.deformable_conv2d_nhwc_python( + data, + offset, + kernel, + stride=(1, 1), + padding=(1, 1), + dilation=(1, 1), + deformable_groups=deformable_groups, + groups=groups, + ) for target, ctx in tvm.testing.enabled_targets(): + if target == "cuda" and layout == "NHWC": + continue # Cannot run NHWC layout on cuda target, only on llvm for kind in ["graph", "debug"]: intrp1 = relay.create_executor(kind, ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data, offset, kernel) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) - test_run(1, 4, 16, 4, 1, 1) - test_run(2, 4, 16, 4, 4, 1) + test_run(1, 4, 16, 4, 1, 1, "NCHW") + test_run(1, 4, 16, 4, 1, 1, "NHWC") + test_run(2, 4, 16, 4, 4, 1, "NCHW") + test_run(2, 4, 16, 4, 4, 1, "NHWC") @tvm.testing.uses_gpu