Skip to content

Commit

Permalink
Added check for index upper bound
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Apr 7, 2021
1 parent 07f3374 commit d6f78ab
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
13 changes: 10 additions & 3 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,20 +304,20 @@ def test_qroialign(self):
pool_size = 5
img_size = 10
n_channels = 2
num_batches = 2
num_imgs = 2
dtype = torch.float

def make_rois(num_rois=1000):
rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
rois[:, 0] = torch.randint(0, num_batches, size=(num_rois,)) # set batch index
rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,)) # set batch index
rois[:, 3:] += rois[:, 1:3] # make sure boxes aren't degenerate
return rois

for aligned in (True, False):
for scale, zero_point in ((1, 0), (2, 10), (0.1, 50)):
for qdtype in (torch.qint8, torch.quint8, torch.qint32):

x = torch.randint(50, 100, size=(num_batches, n_channels, img_size, img_size)).to(dtype)
x = torch.randint(50, 100, size=(num_imgs, n_channels, img_size, img_size)).to(dtype)
qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qdtype)

rois = make_rois()
Expand Down Expand Up @@ -364,6 +364,13 @@ def make_rois(num_rois=1000):
t_scale = torch.full_like(abs_diff, fill_value=scale)
self.assertTrue(torch.allclose(abs_diff, t_scale, atol=1e-5))

x = torch.randint(50, 100, size=(129, 3, 10, 10)).to(dtype)
qx = torch.quantize_per_tensor(x, scale=0, zero_point=1, dtype=torch.qint8)
rois = make_rois(10)
qrois = torch.quantize_per_tensor(rois, scale=0, zero_point=1, dtype=torch.qint8)
with self.assertRaisesRegex(RuntimeError, "There are 129 input images in the batch, but the RoIs tensor"):
ops.roi_align(qx, qrois, output_size=pool_size)


class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
Expand Down
13 changes: 11 additions & 2 deletions torchvision/csrc/ops/quantized/cpu/qroi_align_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ void qroi_align_forward_kernel_impl(

const T* offset_rois = rois + n * 5;
int roi_batch_ind = at::native::dequantize_val(
rois_scale, rois_zp, offset_rois[0]); // FIXME: This can be out of the
// range of the quantized type!!
rois_scale, rois_zp, offset_rois[0]);

// Do not using rounding; this implementation detail is critical
float offset = aligned ? 0.5 : 0.;
Expand Down Expand Up @@ -172,6 +171,16 @@ at::Tensor qroi_align_forward_kernel(
return output;

AT_DISPATCH_QINT_TYPES(input.scalar_type(), "qroi_align_forward_kernel", [&] {
// Note: q_max relates to the input tensor, but we need that of the rois
// tensor. They're the same since we make sure rois and input have the same
// type above.
uint64_t max_indexable = std::numeric_limits<underlying_t>::max() + 1;
std::string err_msg = "There are " + std::to_string(input.size(0)) +
" input images in the batch, but the RoIs tensor can only index up to " +
std::to_string(max_indexable) +
" images. Try to reduce the batch size.";
TORCH_CHECK(input.size(0) <= max_indexable, err_msg);

qroi_align_forward_kernel_impl<scalar_t>(
num_rois,
input,
Expand Down

0 comments on commit d6f78ab

Please sign in to comment.