From e43b80de993fcbb79ac3c10c2cffd97bccf50941 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Thu, 22 Apr 2021 09:45:03 -0600 Subject: [PATCH] [ONNX] Support NMS Center Box (#7900) * [ONNX] Support NMS Center Box * fix silly mistake in contional --- python/tvm/relay/frontend/onnx.py | 16 +++++++++++----- tests/python/frontend/onnx/test_forward.py | 1 - 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 4b159a5716892..fa0eac9bb15e4 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2543,11 +2543,17 @@ def _impl_v10(cls, inputs, attr, params): iou_threshold = inputs[3] score_threshold = inputs[4] - if "center_point_box" in attr: - if attr["center_point_box"] != 0: - raise NotImplementedError( - "Only support center_point_box = 0 in ONNX NonMaxSuprresion" - ) + boxes_dtype = infer_type(boxes).checked_type.dtype + + if attr.get("center_point_box", 0) != 0: + xc, yc, w, h = _op.split(boxes, 4, axis=2) + half_w = w / _expr.const(2.0, boxes_dtype) + half_h = h / _expr.const(2.0, boxes_dtype) + x1 = xc - half_w + x2 = xc + half_w + y1 = yc - half_h + y2 = yc + half_h + boxes = _op.concatenate([y1, x1, y2, x2], axis=2) if iou_threshold is None: iou_threshold = _expr.const(0.0, dtype="float32") diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 0a702c5696e04..1eaae6f30ea28 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4215,7 +4215,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_maxpool_with_argmax_2d_precomputed_strides/", "test_maxunpool_export_with_output_shape/", "test_mvn/", - "test_nonmaxsuppression_center_point_box_format/", "test_qlinearconv/", "test_qlinearmatmul_2D/", "test_qlinearmatmul_3D/",