diff --git a/test/test_ops.py b/test/test_ops.py index 4a73fe8ee74..a0943d48687 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -304,12 +304,12 @@ def test_qroialign(self): pool_size = 5 img_size = 10 n_channels = 2 - num_batches = 2 + num_imgs = 2 dtype = torch.float def make_rois(num_rois=1000): rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype) - rois[:, 0] = torch.randint(0, num_batches, size=(num_rois,)) # set batch index + rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,)) # set batch index rois[:, 3:] += rois[:, 1:3] # make sure boxes aren't degenerate return rois @@ -317,7 +317,7 @@ def make_rois(num_rois=1000): for scale, zero_point in ((1, 0), (2, 10), (0.1, 50)): for qdtype in (torch.qint8, torch.quint8, torch.qint32): - x = torch.randint(50, 100, size=(num_batches, n_channels, img_size, img_size)).to(dtype) + x = torch.randint(50, 100, size=(num_imgs, n_channels, img_size, img_size)).to(dtype) qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qdtype) rois = make_rois() @@ -364,6 +364,13 @@ def make_rois(num_rois=1000): t_scale = torch.full_like(abs_diff, fill_value=scale) self.assertTrue(torch.allclose(abs_diff, t_scale, atol=1e-5)) + x = torch.randint(50, 100, size=(129, 3, 10, 10)).to(dtype) + qx = torch.quantize_per_tensor(x, scale=0, zero_point=1, dtype=torch.qint8) + rois = make_rois(10) + qrois = torch.quantize_per_tensor(rois, scale=0, zero_point=1, dtype=torch.qint8) + with self.assertRaisesRegex(RuntimeError, "There are 129 input images in the batch, but the RoIs tensor"): + ops.roi_align(qx, qrois, output_size=pool_size) + class PSRoIAlignTester(RoIOpTester, unittest.TestCase): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): diff --git a/torchvision/csrc/ops/quantized/cpu/qroi_align_kernel.cpp b/torchvision/csrc/ops/quantized/cpu/qroi_align_kernel.cpp index ad5a7f6166b..7b88cbe2b9b 100644 --- a/torchvision/csrc/ops/quantized/cpu/qroi_align_kernel.cpp +++ b/torchvision/csrc/ops/quantized/cpu/qroi_align_kernel.cpp @@ -36,8 +36,7 @@ void qroi_align_forward_kernel_impl( const T* offset_rois = rois + n * 5; int roi_batch_ind = at::native::dequantize_val( - rois_scale, rois_zp, offset_rois[0]); // FIXME: This can be out of the - // range of the quantized type!! + rois_scale, rois_zp, offset_rois[0]); // Do not using rounding; this implementation detail is critical float offset = aligned ? 0.5 : 0.; @@ -172,6 +171,16 @@ at::Tensor qroi_align_forward_kernel( return output; AT_DISPATCH_QINT_TYPES(input.scalar_type(), "qroi_align_forward_kernel", [&] { + // Note: q_max relates to the input tensor, but we need that of the rois + // tensor. They're the same since we make sure rois and input have the same + // type above. + uint64_t max_indexable = std::numeric_limits::max() + 1; + std::string err_msg = "There are " + std::to_string(input.size(0)) + + " input images in the batch, but the RoIs tensor can only index up to " + + std::to_string(max_indexable) + + " images. Try to reduce the batch size."; + TORCH_CHECK(input.size(0) <= max_indexable, err_msg); + qroi_align_forward_kernel_impl( num_rois, input,