diff --git a/.circleci/config.yml b/.circleci/config.yml index 7c79bbcf9e8..a3fc4ad8ac2 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -257,6 +257,7 @@ jobs: pip install --user --progress-bar off --editable . pip install --user onnx pip install --user onnxruntime + pip install --user pytest python test/test_onnx.py binary_linux_wheel: diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index 397cb8f9cc9..f2d2e472a58 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -257,6 +257,7 @@ jobs: pip install --user --progress-bar off --editable . pip install --user onnx pip install --user onnxruntime + pip install --user pytest python test/test_onnx.py binary_linux_wheel: diff --git a/test/test_onnx.py b/test/test_onnx.py index d0140c79dfc..c9455fbd86a 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -15,21 +15,19 @@ from torchvision.models.detection.image_list import ImageList from torchvision.models.detection.transform import GeneralizedRCNNTransform from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork -from torchvision.models.detection.backbone_utils import resnet_fpn_backbone from torchvision.models.detection.roi_heads import RoIHeads from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead -from torchvision.models.detection.mask_rcnn import MaskRCNNHeads, MaskRCNNPredictor from collections import OrderedDict -import unittest +import pytest from torchvision.ops._register_onnx_ops import _onnx_opset_version -@unittest.skipIf(onnxruntime is None, 'ONNX Runtime unavailable') -class ONNXExporterTester(unittest.TestCase): +@pytest.mark.skipif(onnxruntime is None, reason='ONNX Runtime unavailable') +class TestONNXExporter: @classmethod - def setUpClass(cls): + def setup_class(cls): torch.manual_seed(123) def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True, dynamic_axes=None, @@ -80,7 +78,7 @@ def to_numpy(tensor): torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05) except AssertionError as error: if tolerate_small_mismatch: - self.assertIn("(0.00%)", str(error), str(error)) + assert "(0.00%)" in str(error), str(error) else: raise @@ -161,7 +159,7 @@ def test_roi_align_aligned(self): model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True) self.run_model(model, [(x, single_roi)]) - @unittest.skip # Issue in exporting ROIAlign with aligned = True for malformed boxes + @pytest.mark.skip(reason="Issue in exporting ROIAlign with aligned = True for malformed boxes") def test_roi_align_malformed_boxes(self): x = torch.randn(1, 1, 10, 10, dtype=torch.float32) single_roi = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=torch.float32) @@ -527,4 +525,4 @@ def test_shufflenet_v2_dynamic_axes(self): if __name__ == '__main__': - unittest.main() + pytest.main([__file__])