diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index bb079222de78..40f62682cbde 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -27,14 +27,14 @@ from .scan import exclusive_scan from ..utils import ceil_div from ..math import cast -from .. import reduction -from ..broadcast import minimum -from ..transform import reshape, strided_slice, gather_nd, expand_dims, squeeze +from ..transform import reshape from ..vision.nms_util import ( calculate_overlap, binary_search, collect_selected_indices, + collect_selected_indices_and_scores, run_all_class_nms, + post_process_max_detections, ) @@ -990,7 +990,7 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro return ib.get() -def _collect_selected_indices_tf_ir( +def _collect_selected_indices_and_scores_ir( num_class, selected_indices, selected_scores, @@ -1043,36 +1043,6 @@ def _collect_selected_indices_tf_ir( return ib.get() -def collect_selected_indices_tf(selected_indices, selected_scores, num_detections, row_offsets): - batch_size, num_class = row_offsets.shape - num_boxes = selected_indices.shape[1] - - selected_indices_buf = tvm.tir.decl_buffer( - selected_indices.shape, selected_indices.dtype, "selected_indices_buf", data_alignment=8 - ) - selected_scores_buf = tvm.tir.decl_buffer( - selected_scores.shape, selected_scores.dtype, "selected_scores_buf", data_alignment=8 - ) - num_detections_buf = tvm.tir.decl_buffer( - num_detections.shape, num_detections.dtype, "num_detections_buf", data_alignment=8 - ) - row_offsets_buf = tvm.tir.decl_buffer( - row_offsets.shape, row_offsets.dtype, "row_offsets_buf", data_alignment=8 - ) - - return te.extern( - [(batch_size, num_class * num_boxes, 2), (batch_size, num_class * num_boxes)], - [selected_indices, selected_scores, num_detections, row_offsets], - lambda ins, outs: _collect_selected_indices_tf_ir( - num_class, ins[0], ins[1], ins[2], ins[3], outs[0], outs[1] - ), - dtype=["int64", "float32"], - in_buffers=[selected_indices_buf, selected_scores_buf, num_detections_buf, row_offsets_buf], - name="collect_indices", - tag="collect_indices", - ) - - def all_class_non_max_suppression( boxes, scores, @@ -1141,8 +1111,6 @@ def all_class_non_max_suppression( ) return [selected_indices, num_total_detections] - max_detection_per_batch = max_total_size - selected_indices, selected_scores, num_detections = run_all_class_nms( boxes, sorted_scores, @@ -1153,18 +1121,17 @@ def all_class_non_max_suppression( _nms_loop, return_scores=True, ) - - # tf mode, return (batch_size, max_total_size, 2) num_detections_per_batch = reshape(num_detections, (batch, num_class)) row_offsets, num_total_detections = exclusive_scan( num_detections_per_batch, return_reduction=True, output_dtype="int64", axis=1 ) - selected_indices, selected_scores = collect_selected_indices_tf( - selected_indices, selected_scores, num_detections_per_batch, row_offsets + selected_indices, selected_scores = collect_selected_indices_and_scores( + selected_indices, + selected_scores, + num_detections_per_batch, + row_offsets, + _collect_selected_indices_and_scores_ir, + ) + return post_process_max_detections( + selected_indices, selected_scores, num_total_detections, max_total_size, topk ) - topk_indices = topk(selected_scores, k=max_detection_per_batch, axis=1, ret_type="indices")[0] - topk_indices = expand_dims(topk_indices, axis=0) - final_indices = gather_nd(selected_indices, topk_indices, batch_dims=1) - num_detections = minimum(num_total_detections, max_detection_per_batch) - - return [final_indices, num_detections] diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py index 1c2511d42cd9..fd18a466947c 100644 --- a/python/tvm/topi/vision/nms_util.py +++ b/python/tvm/topi/vision/nms_util.py @@ -18,6 +18,8 @@ """Common utilities used in Non-maximum suppression operators""" import tvm from tvm import te +from ..broadcast import minimum +from ..transform import gather_nd, expand_dims def _get_boundaries(output, box_idx): @@ -128,6 +130,36 @@ def collect_selected_indices(num_class, selected_indices, num_detections, row_of ) +def collect_selected_indices_and_scores( + selected_indices, selected_scores, num_detections, row_offsets, ir +): + batch_size, num_class = row_offsets.shape + num_boxes = selected_indices.shape[1] + + selected_indices_buf = tvm.tir.decl_buffer( + selected_indices.shape, selected_indices.dtype, "selected_indices_buf", data_alignment=8 + ) + selected_scores_buf = tvm.tir.decl_buffer( + selected_scores.shape, selected_scores.dtype, "selected_scores_buf", data_alignment=8 + ) + num_detections_buf = tvm.tir.decl_buffer( + num_detections.shape, num_detections.dtype, "num_detections_buf", data_alignment=8 + ) + row_offsets_buf = tvm.tir.decl_buffer( + row_offsets.shape, row_offsets.dtype, "row_offsets_buf", data_alignment=8 + ) + + return te.extern( + [(batch_size, num_class * num_boxes, 2), (batch_size, num_class * num_boxes)], + [selected_indices, selected_scores, num_detections, row_offsets], + lambda ins, outs: ir(num_class, ins[0], ins[1], ins[2], ins[3], outs[0], outs[1]), + dtype=["int64", "float32"], + in_buffers=[selected_indices_buf, selected_scores_buf, num_detections_buf, row_offsets_buf], + name="collect_indices", + tag="collect_indices", + ) + + def _all_class_nms_ir( boxes, sorted_scores, @@ -208,7 +240,7 @@ def run_all_class_nms( max_output_size_per_class, iou_threshold, nms_loop, - return_scores=False + return_scores=False, ): """The core all class NMS routine @@ -276,7 +308,7 @@ def run_all_class_nms( iou_threshold, max_output_size_per_class, outs[0], # box_indices - None, # scores + None, # scores outs[1], # num_selected_boxes nms_loop, ), @@ -319,3 +351,13 @@ def run_all_class_nms( name="all_class_nms", tag="all_class_nms", ) + + +def post_process_max_detections( + selected_indices, selected_scores, num_total_detections, max_total_size, topk_func +): + topk_indices = topk_func(selected_scores, k=max_total_size, axis=1, ret_type="indices")[0] + topk_indices = expand_dims(topk_indices, axis=0) + final_indices = gather_nd(selected_indices, topk_indices, batch_dims=1) + num_detections = minimum(num_total_detections, max_total_size) + return [final_indices, num_detections]