diff --git a/test/test_onnx.py b/test/test_onnx.py index 03bd667ac24..cf9e1c6bb9d 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -365,11 +365,23 @@ def test_heatmaps_to_keypoints(self): assert torch.all(out2[0].eq(out_trace2[0])) assert torch.all(out2[1].eq(out_trace2[1])) - @unittest.skip("Disable test until Argmax is updated in ONNX") def test_keypoint_rcnn(self): - images, test_images = self.get_test_images() + class KeyPointRCNN(torch.nn.Module): + def __init__(self): + super(KeyPointRCNN, self).__init__() + self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, + min_size=200, + max_size=300) + + def forward(self, images): + output = self.model(images) + # TODO: The keypoints_scores require the use of Argmax that is updated in ONNX. + # For now we are testing all the output of KeypointRCNN except keypoints_scores. + # Enable When Argmax is updated in ONNX Runtime. + return output[0]['boxes'], output[0]['labels'], output[0]['scores'], output[0]['keypoints'] - model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) + images, test_images = self.get_test_images() + model = KeyPointRCNN() model.eval() model(test_images) self.run_model(model, [(images,), (test_images,)])