From 5af2cf04916518a9b531883f8bb6f722e9c0327d Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Tue, 13 Apr 2021 01:56:38 -0700
Subject: [PATCH] [fbsync] Add Quantized version of RoIAlign (#3624)

Summary:
* WIP
* clang
* docs
* extracted out common utils
* Use better quantization function and pass tensors as parameters
* proper dequantization
* Some tests
* Dequantization optimization, seems to gain a few ms
* clang-format
* again
* more correct test. Had to remove optimization although it almost works
* Also test aligned=True
* remove useless part
* more docs and comments
* Put back optimization with more robust test
* Added check for index upper bound
* avoid possible overflow
* Move common function into common.h
* oops
* scale=1,zero_point=0 makes more sense
* Force batch size of 1 to prevent any indexing bug
* format
* format again
* updated docstring
* put back description comment for pre_calc_bilinear_interpolate
* revert most changes to docstring as it's taken care of in another PR

Reviewed By: NicolasHug

Differential Revision: D27706946

fbshipit-source-id: 2ae1614c214ea676b4f7705dc0716efd9f34330e
---
 test/test_ops.py                              |  72 ++++++
 torchvision/csrc/ops/cpu/roi_align_common.h   | 128 +++++++++++
 torchvision/csrc/ops/cpu/roi_align_kernel.cpp | 125 +----------
 .../ops/quantized/cpu/qroi_align_kernel.cpp   | 208 ++++++++++++++++++
 torchvision/ops/roi_align.py                  |   1 +
 5 files changed, 417 insertions(+), 117 deletions(-)
 create mode 100644 torchvision/csrc/ops/cpu/roi_align_common.h
 create mode 100644 torchvision/csrc/ops/quantized/cpu/qroi_align_kernel.cpp

diff --git a/test/test_ops.py b/test/test_ops.py
index 0031da45cce..8c63c9c29c6 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -299,6 +299,78 @@ def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwa
         for aligned in (True, False):
             super()._test_forward(device, contiguous, x_dtype, rois_dtype, aligned=aligned)
 
+    def test_qroialign(self):
+        """Make sure quantized version of RoIAlign is close to float version"""
+        pool_size = 5
+        img_size = 10
+        n_channels = 2
+        num_imgs = 1
+        dtype = torch.float
+
+        def make_rois(num_rois=1000):
+            rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
+            rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,))  # set batch index
+            rois[:, 3:] += rois[:, 1:3]  # make sure boxes aren't degenerate
+            return rois
+
+        for aligned in (True, False):
+            for scale, zero_point in ((1, 0), (2, 10), (0.1, 50)):
+                for qdtype in (torch.qint8, torch.quint8, torch.qint32):
+
+                    x = torch.randint(50, 100, size=(num_imgs, n_channels, img_size, img_size)).to(dtype)
+                    qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qdtype)
+
+                    rois = make_rois()
+                    qrois = torch.quantize_per_tensor(rois, scale=scale, zero_point=zero_point, dtype=qdtype)
+
+                    x, rois = qx.dequantize(), qrois.dequantize()  # we want to pass the same inputs
+
+                    y = ops.roi_align(
+                        x,
+                        rois,
+                        output_size=pool_size,
+                        spatial_scale=1,
+                        sampling_ratio=-1,
+                        aligned=aligned,
+                    )
+                    qy = ops.roi_align(
+                        qx,
+                        qrois,
+                        output_size=pool_size,
+                        spatial_scale=1,
+                        sampling_ratio=-1,
+                        aligned=aligned,
+                    )
+
+                    # The output qy is itself a quantized tensor, and there might have been a loss of
+                    # info when it was quantized.
+                    # For a fair comparison we need to quantize y as well.
+                    quantized_float_y = torch.quantize_per_tensor(y, scale=scale, zero_point=zero_point, dtype=qdtype)
+
+                    try:
+                        # Ideally, we would assert this, which passes with (scale, zero) == (1, 0)
+                        self.assertTrue((qy == quantized_float_y).all())
+                    except AssertionError:
+                        # But because the computations aren't exactly the same between the two RoIAlign
+                        # procedures, some rounding error may lead to a difference equal to `scale` in
+                        # the output. For example with (scale, zero) = (2, 10), 45.00000... will be
+                        # quantized to 44 but 45.00000001 will be rounded to 46. We make sure below that:
+                        # - such discrepancies between qy and quantized_float_y are very rare (less than 5%)
+                        # - any difference between qy and quantized_float_y is == scale
+                        diff_idx = torch.where(qy != quantized_float_y)
+                        num_diff = diff_idx[0].numel()
+                        self.assertTrue(num_diff / qy.numel() < .05)
+
+                        abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize())
+                        t_scale = torch.full_like(abs_diff, fill_value=scale)
+                        self.assertTrue(torch.allclose(abs_diff, t_scale, atol=1e-5))
+
+        # Batches of more than one image are not supported for quantized inputs
+        x = torch.randint(50, 100, size=(2, 3, 10, 10)).to(dtype)
+        qx = torch.quantize_per_tensor(x, scale=1, zero_point=0, dtype=torch.qint8)
+        rois = make_rois(10)
+        qrois = torch.quantize_per_tensor(rois, scale=1, zero_point=0, dtype=torch.qint8)
+        with self.assertRaisesRegex(RuntimeError, "Only one image per batch is allowed"):
+            ops.roi_align(qx, qrois, output_size=pool_size)
+
 
 class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
diff --git a/torchvision/csrc/ops/cpu/roi_align_common.h b/torchvision/csrc/ops/cpu/roi_align_common.h
new file mode 100644
index 00000000000..e10c67b5b79
--- /dev/null
+++ b/torchvision/csrc/ops/cpu/roi_align_common.h
@@ -0,0 +1,128 @@
+#pragma once
+
+#include <vector>
+
+namespace vision {
+namespace ops {
+namespace detail {
+
+template <typename T>
+struct PreCalc {
+  int pos1;
+  int pos2;
+  int pos3;
+  int pos4;
+  T w1;
+  T w2;
+  T w3;
+  T w4;
+};
+
+// This helper computes the interpolation weights (w1, w2...) for every sampling
+// point of a given box. There are pool_height * pool_width * roi_bin_grid_h *
+// roi_bin_grid_w such sampling points.
+//
+// The weights (w1, w2...) are computed as the areas in this figure:
+// https://en.wikipedia.org/wiki/Bilinear_interpolation#/media/File:Bilinear_interpolation_visualisation.svg
+// and pos1, pos2 etc correspond to the indices of their respective pixels.
+//
+// Note: the weights and indices are shared across all channels, which is why
+// they are pre-calculated prior to the main loop in the RoIAlign kernel.
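+//
+// As a concrete sketch (illustrative, derived from the code below): for an
+// in-bounds sampling point (y, x), with y_low = (int)y, x_low = (int)x,
+// ly = y - y_low and lx = x - x_low, the four weights are
+//   w1 = (1 - ly) * (1 - lx)   (top-left pixel)
+//   w2 = (1 - ly) * lx         (top-right pixel)
+//   w3 = ly * (1 - lx)         (bottom-left pixel)
+//   w4 = ly * lx               (bottom-right pixel)
+// and they always sum to 1.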
+// Implementation taken from Caffe2.
+template <typename T>
+void pre_calc_for_bilinear_interpolate(
+    int height,
+    int width,
+    int pooled_height,
+    int pooled_width,
+    T roi_start_h,
+    T roi_start_w,
+    T bin_size_h,
+    T bin_size_w,
+    int roi_bin_grid_h,
+    int roi_bin_grid_w,
+    std::vector<PreCalc<T>>& pre_calc) {
+  int pre_calc_index = 0;
+  for (int ph = 0; ph < pooled_height; ph++) {
+    for (int pw = 0; pw < pooled_width; pw++) {
+      for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+        const T yy = roi_start_h + ph * bin_size_h +
+            static_cast<T>(iy + .5f) * bin_size_h /
+                static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+        for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+          const T xx = roi_start_w + pw * bin_size_w +
+              static_cast<T>(ix + .5f) * bin_size_w /
+                  static_cast<T>(roi_bin_grid_w);
+
+          T x = xx;
+          T y = yy;
+          // deal with sampling points that fall outside the feature map boundary
+          if (y < -1.0 || y > height || x < -1.0 || x > width) {
+            // empty
+            PreCalc<T> pc;
+            pc.pos1 = 0;
+            pc.pos2 = 0;
+            pc.pos3 = 0;
+            pc.pos4 = 0;
+            pc.w1 = 0;
+            pc.w2 = 0;
+            pc.w3 = 0;
+            pc.w4 = 0;
+            pre_calc[pre_calc_index] = pc;
+            pre_calc_index += 1;
+            continue;
+          }
+
+          if (y <= 0) {
+            y = 0;
+          }
+          if (x <= 0) {
+            x = 0;
+          }
+
+          int y_low = (int)y;
+          int x_low = (int)x;
+          int y_high;
+          int x_high;
+
+          if (y_low >= height - 1) {
+            y_high = y_low = height - 1;
+            y = (T)y_low;
+          } else {
+            y_high = y_low + 1;
+          }
+
+          if (x_low >= width - 1) {
+            x_high = x_low = width - 1;
+            x = (T)x_low;
+          } else {
+            x_high = x_low + 1;
+          }
+
+          T ly = y - y_low;
+          T lx = x - x_low;
+          T hy = 1. - ly, hx = 1. - lx;
+          T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+          // save weights and indices
+          PreCalc<T> pc;
+          pc.pos1 = y_low * width + x_low;
+          pc.pos2 = y_low * width + x_high;
+          pc.pos3 = y_high * width + x_low;
+          pc.pos4 = y_high * width + x_high;
+          pc.w1 = w1;
+          pc.w2 = w2;
+          pc.w3 = w3;
+          pc.w4 = w4;
+          pre_calc[pre_calc_index] = pc;
+
+          pre_calc_index += 1;
+        }
+      }
+    }
+  }
+}
+
+} // namespace detail
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/csrc/ops/cpu/roi_align_kernel.cpp b/torchvision/csrc/ops/cpu/roi_align_kernel.cpp
index dc0c38cd314..e6684e953d0 100644
--- a/torchvision/csrc/ops/cpu/roi_align_kernel.cpp
+++ b/torchvision/csrc/ops/cpu/roi_align_kernel.cpp
@@ -1,120 +1,13 @@
 #include <ATen/ATen.h>
 #include <torch/library.h>
 
+#include "./roi_align_common.h"
+
 namespace vision {
 namespace ops {
 
 namespace {
 
-// implementation taken from Caffe2
-template <typename T>
-struct PreCalc {
-  int pos1;
-  int pos2;
-  int pos3;
-  int pos4;
-  T w1;
-  T w2;
-  T w3;
-  T w4;
-};
-
-template <typename T>
-void pre_calc_for_bilinear_interpolate(
-    int height,
-    int width,
-    int pooled_height,
-    int pooled_width,
-    int iy_upper,
-    int ix_upper,
-    T roi_start_h,
-    T roi_start_w,
-    T bin_size_h,
-    T bin_size_w,
-    int roi_bin_grid_h,
-    int roi_bin_grid_w,
-    std::vector<PreCalc<T>>& pre_calc) {
-  int pre_calc_index = 0;
-  for (int ph = 0; ph < pooled_height; ph++) {
-    for (int pw = 0; pw < pooled_width; pw++) {
-      for (int iy = 0; iy < iy_upper; iy++) {
-        const T yy = roi_start_h + ph * bin_size_h +
-            static_cast<T>(iy + .5f) * bin_size_h /
-                static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
-        for (int ix = 0; ix < ix_upper; ix++) {
-          const T xx = roi_start_w + pw * bin_size_w +
-              static_cast<T>(ix + .5f) * bin_size_w /
-                  static_cast<T>(roi_bin_grid_w);
-
-          T x = xx;
-          T y = yy;
-          // deal with: inverse elements are out of feature map boundary
-          if (y < -1.0 || y > height || x < -1.0 || x > width) {
-            // empty
-            PreCalc<T> pc;
-            pc.pos1 = 0;
-            pc.pos2 = 0;
-            pc.pos3 = 0;
-            pc.pos4 = 0;
-            pc.w1 = 0;
-            pc.w2 = 0;
-            pc.w3 = 0;
-            pc.w4 = 0;
-            pre_calc[pre_calc_index] = pc;
-            pre_calc_index += 1;
-            continue;
-          }
-
-          if (y <= 0) {
-            y = 0;
-          }
-          if (x <= 0) {
-            x = 0;
-          }
-
-          int y_low = (int)y;
-          int x_low = (int)x;
-          int y_high;
-          int x_high;
-
-          if (y_low >= height - 1) {
-            y_high = y_low = height - 1;
-            y = (T)y_low;
-          } else {
-            y_high = y_low + 1;
-          }
-
-          if (x_low >= width - 1) {
-            x_high = x_low = width - 1;
-            x = (T)x_low;
-          } else {
-            x_high = x_low + 1;
-          }
-
-          T ly = y - y_low;
-          T lx = x - x_low;
-          T hy = 1. - ly, hx = 1. - lx;
-          T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
-
-          // save weights and indeces
-          PreCalc<T> pc;
-          pc.pos1 = y_low * width + x_low;
-          pc.pos2 = y_low * width + x_high;
-          pc.pos3 = y_high * width + x_low;
-          pc.pos4 = y_high * width + x_high;
-          pc.w1 = w1;
-          pc.w2 = w2;
-          pc.w3 = w3;
-          pc.w4 = w4;
-          pre_calc[pre_calc_index] = pc;
-
-          pre_calc_index += 1;
-        }
-      }
-    }
-  }
-}
-
 template <typename T>
 void roi_align_forward_kernel_impl(
     int n_rois,
@@ -167,17 +60,15 @@ void roi_align_forward_kernel_impl(
     // When the grid is empty, output zeros.
     const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
 
-    // we want to precalculate indeces and weights shared by all chanels,
-    // this is the key point of optimiation
-    std::vector<PreCalc<T>> pre_calc(
+    // we want to precalculate indices and weights shared by all channels;
+    // this is the key point of the optimization
+    std::vector<detail::PreCalc<T>> pre_calc(
         roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
-    pre_calc_for_bilinear_interpolate(
+    detail::pre_calc_for_bilinear_interpolate(
         height,
         width,
         pooled_height,
         pooled_width,
-        roi_bin_grid_h,
-        roi_bin_grid_w,
         roi_start_h,
         roi_start_w,
         bin_size_h,
@@ -199,7 +90,7 @@ void roi_align_forward_kernel_impl(
         T output_val = 0.;
         for (int iy = 0; iy < roi_bin_grid_h; iy++) {
           for (int ix = 0; ix < roi_bin_grid_w; ix++) {
-            PreCalc<T> pc = pre_calc[pre_calc_index];
+            detail::PreCalc<T> pc = pre_calc[pre_calc_index];
             output_val += pc.w1 * offset_input[pc.pos1] +
                 pc.w2 * offset_input[pc.pos2] + pc.w3 * offset_input[pc.pos3] +
                 pc.w4 * offset_input[pc.pos4];
@@ -207,7 +98,7 @@ void roi_align_forward_kernel_impl(
             pre_calc_index += 1;
           }
         }
-        output_val /= count;
+        output_val /= count; // Average pooling
 
         output[index] = output_val;
       } // for pw
diff --git a/torchvision/csrc/ops/quantized/cpu/qroi_align_kernel.cpp b/torchvision/csrc/ops/quantized/cpu/qroi_align_kernel.cpp
new file mode 100644
index 00000000000..e34b277747e
--- /dev/null
+++ b/torchvision/csrc/ops/quantized/cpu/qroi_align_kernel.cpp
@@ -0,0 +1,208 @@
+#include <ATen/ATen.h>
+#include <ATen/native/quantized/affine_quantizer.h>
+#include <torch/library.h>
+
+#include "../../cpu/roi_align_common.h"
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+template <typename T>
+void qroi_align_forward_kernel_impl(
+    int n_rois,
+    const at::Tensor& t_input,
+    const float& spatial_scale,
+    int channels,
+    int height,
+    int width,
+    int pooled_height,
+    int pooled_width,
+    int sampling_ratio,
+    bool aligned,
+    const at::Tensor& t_rois,
+    T* output) {
+  const T* input = t_input.contiguous().data_ptr<T>();
+  int64_t input_zp = t_input.q_zero_point();
+  float input_scale = t_input.q_scale();
+
+  const T* rois = t_rois.contiguous().data_ptr<T>();
+  int64_t rois_zp = t_rois.q_zero_point();
+  float rois_scale = t_rois.q_scale();
+
+  for (int n = 0; n < n_rois; n++) {
+    int index_n = n * channels * pooled_width * pooled_height;
+
+    const T* offset_rois = rois + n * 5;
+
+    // FIXME: change this when batches of size > 1 are allowed
+    const int roi_batch_ind = 0;
+
+    // Do not use rounding; this implementation detail is critical
+    float offset = aligned ? 0.5 : 0.;
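+    // Note: with aligned=true, the box coordinates computed below are shifted
+    // by half a pixel so that sampling happens relative to pixel centers
+    // rather than pixel corners; aligned=false keeps the legacy behavior.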
+    float roi_start_w =
+        at::native::dequantize_val(rois_scale, rois_zp, offset_rois[1]) *
+            spatial_scale -
+        offset;
+    float roi_start_h =
+        at::native::dequantize_val(rois_scale, rois_zp, offset_rois[2]) *
+            spatial_scale -
+        offset;
+    float roi_end_w =
+        at::native::dequantize_val(rois_scale, rois_zp, offset_rois[3]) *
+            spatial_scale -
+        offset;
+    float roi_end_h =
+        at::native::dequantize_val(rois_scale, rois_zp, offset_rois[4]) *
+            spatial_scale -
+        offset;
+
+    float roi_width = roi_end_w - roi_start_w;
+    float roi_height = roi_end_h - roi_start_h;
+    if (!aligned) {
+      // Force malformed ROIs to be 1x1
+      roi_width = std::max(roi_width, 1.f);
+      roi_height = std::max(roi_height, 1.f);
+    }
+
+    float bin_size_h = roi_height / pooled_height;
+    float bin_size_w = roi_width / pooled_width;
+
+    // We use roi_bin_grid to sample the grid and mimic the integral
+    int roi_bin_grid_h = (sampling_ratio > 0)
+        ? sampling_ratio
+        : ceil(roi_height / pooled_height); // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+    // We do average (integral) pooling inside a bin
+    // When the grid is empty, output zeros.
+    const float count =
+        std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
+
+    // we want to precalculate indices and weights shared by all channels;
+    // this is the key point of the optimization
+    std::vector<detail::PreCalc<float>> pre_calc(
+        roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
+    detail::pre_calc_for_bilinear_interpolate(
+        height,
+        width,
+        pooled_height,
+        pooled_width,
+        roi_start_h,
+        roi_start_w,
+        bin_size_h,
+        bin_size_w,
+        roi_bin_grid_h,
+        roi_bin_grid_w,
+        pre_calc);
+
+    for (int c = 0; c < channels; c++) {
+      int index_n_c = index_n + c * pooled_width * pooled_height;
+      const T* offset_input =
+          input + (roi_batch_ind * channels + c) * height * width;
+      int pre_calc_index = 0;
+
+      for (int ph = 0; ph < pooled_height; ph++) {
+        for (int pw = 0; pw < pooled_width; pw++) {
+          int index = index_n_c + ph * pooled_width + pw;
+
+          float output_val = 0.;
+          float sum_w = 0.;
+          for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+            for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+              detail::PreCalc<float> pc = pre_calc[pre_calc_index];
+
+              // Optimization: we use the raw values here and we'll dequantize
+              // later
+              output_val += pc.w1 * offset_input[pc.pos1].val_ +
+                  pc.w2 * offset_input[pc.pos2].val_ +
+                  pc.w3 * offset_input[pc.pos3].val_ +
+                  pc.w4 * offset_input[pc.pos4].val_;
+              sum_w += pc.w1 + pc.w2 + pc.w3 + pc.w4;
+
+              pre_calc_index += 1;
+            }
+          }
+          // Dequantize here: since dequantize(v) = input_scale * (v - input_zp),
+          // sum_i w_i * dequantize(v_i)
+          //     = input_scale * (sum_i w_i * v_i - input_zp * sum_i w_i)
+          output_val = input_scale * (output_val - (float)input_zp * sum_w);
+
+          output_val /= count; // Average pooling
+
+          output[index] =
+              at::native::quantize_val<T>(input_scale, input_zp, output_val);
+        } // for pw
+      } // for ph
+    } // for c
+  } // for n
+}
+
+at::Tensor qroi_align_forward_kernel(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t sampling_ratio,
+    bool aligned) {
+  TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
+  TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
+  TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");
+  // The first column of the RoI tensor is an image index, but not all indices
+  // are representable depending on the quantization. For example 1, 3, 5...
+  // indices can't be represented when qscale is 2.
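+  // (Since dequantize(q) = qscale * (q - zero_point), only integer multiples
+  // of qscale are representable; with qscale 2 every representable value is
+  // even, so odd indices are unreachable.)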
+  // To prevent any bug, we force a batch size of 1 and ignore the first
+  // column.
+  TORCH_CHECK(
+      input.size(0) == 1,
+      "Only one image per batch is allowed in roi_align when quantized tensors are passed.");
+
+  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
+
+  at::CheckedFrom c = "qroi_align_forward_kernel";
+  at::checkAllSameType(c, {input_t, rois_t});
+
+  auto num_rois = rois.size(0);
+  auto channels = input.size(1);
+  auto height = input.size(2);
+  auto width = input.size(3);
+
+  // FIXME: This is private, API might change:
+  // https://github.com/pytorch/pytorch/wiki/Introducing-Quantized-Tensor#quantized-tensor-apis
+  at::Tensor output = at::_empty_affine_quantized(
+      {num_rois, channels, pooled_height, pooled_width},
+      input.options(),
+      input.q_scale(),
+      input.q_zero_point());
+
+  if (output.numel() == 0)
+    return output;
+
+  AT_DISPATCH_QINT_TYPES(input.scalar_type(), "qroi_align_forward_kernel", [&] {
+    qroi_align_forward_kernel_impl<scalar_t>(
+        num_rois,
+        input,
+        spatial_scale,
+        channels,
+        height,
+        width,
+        pooled_height,
+        pooled_width,
+        sampling_ratio,
+        aligned,
+        rois,
+        output.data_ptr<scalar_t>());
+  });
+  return output;
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, QuantizedCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::roi_align"),
+      TORCH_FN(qroi_align_forward_kernel));
+}
+
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py
index ffcafc9f50d..c0ac14329d4 100644
--- a/torchvision/ops/roi_align.py
+++ b/torchvision/ops/roi_align.py
@@ -21,6 +21,7 @@ def roi_align(
 
     Args:
         input (Tensor[N, C, H, W]): input tensor
+            If the tensor is quantized, we expect a batch size of ``N == 1``.
         boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2)
             format where the regions will be taken from.
             The coordinate must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
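
For illustration, a minimal sketch of how the new quantized path is exercised from Python once this patch is built. The shapes and quantization parameters below are arbitrary (not taken from the patch), and the rounding example mirrors the test comment above, using 45.001 instead of 45.00000001 so the perturbation survives float32 precision:

    import torch
    from torchvision import ops

    # Quantized roi_align requires a single image per batch (see the
    # TORCH_CHECK above), so N == 1 here.
    x = torch.randint(50, 100, size=(1, 2, 10, 10)).to(torch.float)
    qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)

    # Boxes are rows of (batch_index, x1, y1, x2, y2) and must be quantized too.
    rois = torch.tensor([[0.0, 1.0, 1.0, 8.0, 8.0]])
    qrois = torch.quantize_per_tensor(rois, scale=1.0, zero_point=0, dtype=torch.quint8)

    qy = ops.roi_align(qx, qrois, output_size=5, spatial_scale=1.0,
                       sampling_ratio=-1, aligned=True)
    # The output reuses the input's quantization parameters.
    assert qy.is_quantized and qy.q_scale() == qx.q_scale()

    # Rounding edge case from the test comment, with (scale, zero_point) = (2, 10):
    # 45.0 / 2 = 22.5 rounds (half-to-even) to 22 and dequantizes to 44.0, while
    # a value just above the midpoint rounds to 23 and dequantizes to 46.0.
    a = torch.quantize_per_tensor(torch.tensor([45.0]), 2.0, 10, torch.quint8)
    b = torch.quantize_per_tensor(torch.tensor([45.001]), 2.0, 10, torch.quint8)
    print(a.dequantize().item(), b.dequantize().item())  # 44.0 46.0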