diff --git a/test/test_ops.py b/test/test_ops.py index 86e1c2b0ba7..9eb6342e378 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -548,5 +548,30 @@ def test_frozenbatchnorm2d_repr(self): self.assertEqual(t.__repr__(), expected_string) +class BoxConversionTester(unittest.TestCase): + @staticmethod + def _get_box_sequences(): + # Define here the argument type of `boxes` supported by region pooling operations + box_tensor = torch.tensor([[0, 0, 0, 100, 100], [1, 0, 0, 100, 100]], dtype=torch.float) + box_list = [torch.tensor([[0, 0, 100, 100]], dtype=torch.float), + torch.tensor([[0, 0, 100, 100]], dtype=torch.float)] + box_tuple = tuple(box_list) + return box_tensor, box_list, box_tuple + + def test_check_roi_boxes_shape(self): + # Ensure common sequences of tensors are supported + for box_sequence in self._get_box_sequences(): + self.assertIsNone(ops._utils.check_roi_boxes_shape(box_sequence)) + + def test_convert_boxes_to_roi_format(self): + # Ensure common sequences of tensors yield the same result + ref_tensor = None + for box_sequence in self._get_box_sequences(): + if ref_tensor is None: + ref_tensor = box_sequence + else: + self.assertTrue(torch.equal(ref_tensor, ops._utils.convert_boxes_to_roi_format(box_sequence))) + + if __name__ == '__main__': unittest.main() diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index 714022f0421..f514664042b 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -27,7 +27,7 @@ def convert_boxes_to_roi_format(boxes): def check_roi_boxes_shape(boxes): - if isinstance(boxes, list): + if isinstance(boxes, (list, tuple)): for _tensor in boxes: assert _tensor.size(1) == 4, \ 'The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]'