From ca9470ba68e68c81902b0a3bad4bf5b5f0aa311e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 30 May 2021 05:25:13 +0900 Subject: [PATCH] update op definition --- python/tvm/relay/frontend/tensorflow.py | 11 +++++++++- python/tvm/relay/op/vision/nms.py | 6 +++++- python/tvm/topi/cuda/nms.py | 7 +------ python/tvm/topi/vision/nms.py | 12 ++--------- python/tvm/topi/vision/nms_util.py | 11 ---------- src/relay/op/vision/nms.cc | 28 ++++++++++++++----------- 6 files changed, 34 insertions(+), 41 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index a1c9d7efd4b5..0318c50c9b8c 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -804,7 +804,11 @@ def all_class_impl( clip_boxes, mod, ): - indices, num_detections = _op.vision.all_class_non_max_suppression( + ( + selected_indices, + selected_scores, + num_detections, + ) = _op.vision.all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, @@ -813,6 +817,11 @@ def all_class_impl( max_total_size, output_format="tensorflow", ) + topk_indices = _op.topk(selected_scores, k=max_total_size, axis=1, ret_type="indices") + topk_indices = _op.expand_dims(topk_indices, axis=0) + indices = _op.gather_nd(selected_indices, topk_indices, batch_dims=1) + num_detections = _op.minimum(num_detections, _op.const(max_total_size, dtype="int64")) + nmsed_box_indices = _op.take(indices, _op.const(1), axis=2) nmsed_classes = _op.cast(_op.take(indices, _op.const(0), axis=2), "float32") nmsed_boxes = _op.gather_nd(boxes, _op.expand_dims(nmsed_box_indices, axis=0), batch_dims=1) diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index 785579cd7973..dc1a08e4b6ac 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -209,4 +209,8 @@ def all_class_non_max_suppression( max_total_size, output_format, ) - return expr.TupleWrapper(out, 2) + + if output_format == "onnx": + return expr.TupleWrapper(out, 2) + + return expr.TupleWrapper(out, 3) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 0112b9a81716..e39113861331 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -34,7 +34,6 @@ collect_selected_indices, collect_selected_indices_and_scores, run_all_class_nms, - post_process_max_detections, ) @@ -1134,8 +1133,4 @@ def all_class_non_max_suppression( _collect_selected_indices_and_scores_ir, ) - topk_indices = topk(selected_scores, k=max_total_size, axis=1, ret_type="indices")[0] - - return post_process_max_detections( - selected_indices, topk_indices, num_total_detections, max_total_size - ) + return [selected_indices, selected_scores, num_total_detections] diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 65dc6136b2a6..7485719627c4 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -22,7 +22,7 @@ from tvm.te import hybrid from tvm.tir import if_then_else -from ..sort import argsort, topk +from ..sort import argsort from ..math import cast from ..transform import reshape, gather from .. import reduction @@ -32,7 +32,6 @@ collect_selected_indices, collect_selected_indices_and_scores, run_all_class_nms, - post_process_max_detections, ) @@ -862,11 +861,4 @@ def all_class_non_max_suppression( _collect_selected_indices_and_scores_ir, ) - topk_indices = topk(selected_scores, k=max_total_size, axis=1, ret_type="indices") - - return post_process_max_detections( - selected_indices, - topk_indices, - num_total_detections, - max_total_size, - ) + return [selected_indices, selected_scores, num_total_detections] diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py index 86b6c7b3b5e6..dfa3c0788295 100644 --- a/python/tvm/topi/vision/nms_util.py +++ b/python/tvm/topi/vision/nms_util.py @@ -18,8 +18,6 @@ """Common utilities used in Non-maximum suppression operators""" import tvm from tvm import te -from ..broadcast import minimum -from ..transform import gather_nd, expand_dims def _get_boundaries(output, box_idx): @@ -302,12 +300,3 @@ def run_all_class_nms( name="all_class_nms", tag="all_class_nms", ) - - -def post_process_max_detections( - selected_indices, topk_indices, num_total_detections, max_total_size -): - topk_indices = expand_dims(topk_indices, axis=0) - final_indices = gather_nd(selected_indices, topk_indices, batch_dims=1) - num_detections = minimum(num_total_detections, max_total_size) - return [final_indices, num_detections] diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 1e63ccd04721..718c9c0a3857 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -152,29 +152,33 @@ bool AllClassNMSRel(const Array& types, int num_inputs, const Attrs& attrs IndexExpr num_classes = scores_shape[1]; IndexExpr num_boxes = boxes_shape[1]; - IndexExpr num_total_boxes = Any(); - if (!batch.as() && !num_boxes.as()) { - num_total_boxes = batch * num_classes * num_boxes; - } - const auto* param = attrs.as(); CHECK(param); std::vector fields; if (param->output_format == "onnx") { + IndexExpr num_total_boxes = Any(); + if (!batch.as() && !num_boxes.as()) { + num_total_boxes = batch * num_classes * num_boxes; + } std::vector oshape{num_total_boxes, 3}; - std::vector countshape{1}; + std::vector counts_shape{1}; fields.push_back(TensorType(oshape, DataType::Int(64))); - fields.push_back(TensorType(countshape, DataType::Int(64))); + fields.push_back(TensorType(counts_shape, DataType::Int(64))); } else { ICHECK(param->max_total_size) << "max_total_size required for tf mode"; Integer max_total_size = param->max_total_size.value(); - std::vector oshape{batch, max_total_size, 2}; - std::vector countshape{batch}; - fields.push_back(TensorType(oshape, DataType::Int(64))); - fields.push_back(TensorType(countshape, DataType::Int(64))); + IndexExpr num_total_boxes_per_batch = Any(); + if (!num_boxes.as()) { + num_total_boxes_per_batch = num_classes * num_boxes; + } + std::vector indices_shape{batch, num_total_boxes_per_batch, 2}; + std::vector scores_shape{batch, num_total_boxes_per_batch}; + std::vector counts_shape{batch}; + fields.push_back(TensorType(indices_shape, DataType::Int(64))); + fields.push_back(TensorType(scores_shape, DataType::Float(32))); + fields.push_back(TensorType(counts_shape, DataType::Int(64))); } - reporter->Assign(types[5], TupleType(Array(fields))); return true; }