diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 360d568030be..a5508da65d1d 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -1100,25 +1100,6 @@ def all_class_non_max_suppression( sorted_scores, sorted_indices = _dispatch_sort(scores, ret_type="both") valid_count = _get_valid_box_count(sorted_scores, score_threshold) - if output_format == "onnx": - selected_indices, num_detections = run_all_class_nms( - boxes, - sorted_scores, - sorted_indices, - valid_count, - max_output_boxes_per_class, - iou_threshold, - _nms_loop, - ) - - row_offsets, num_total_detections = exclusive_scan( - num_detections, return_reduction=True, output_dtype="int64" - ) - selected_indices = collect_selected_indices( - num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir - ) - return [selected_indices, num_total_detections] - selected_indices, selected_scores, num_detections = run_all_class_nms( boxes, sorted_scores, @@ -1127,8 +1108,18 @@ def all_class_non_max_suppression( max_output_boxes_per_class, iou_threshold, _nms_loop, - return_scores=True, + return_scores=(output_format == "tensorflow"), ) + + if output_format == "onnx": + row_offsets, num_total_detections = exclusive_scan( + num_detections, return_reduction=True, output_dtype="int64" + ) + selected_indices = collect_selected_indices( + num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + ) + return [selected_indices, num_total_detections] + num_detections_per_batch = reshape(num_detections, (batch, num_class)) row_offsets, num_total_detections = exclusive_scan( num_detections_per_batch, return_reduction=True, output_dtype="int64", axis=1 diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 010879e8af7c..b85c1e1c44c3 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -829,23 +829,6 @@ def all_class_non_max_suppression( valid_count = _get_valid_box_count(sorted_scores, score_threshold) - if output_format == "onnx": - selected_indices, num_detections = run_all_class_nms( - boxes, - sorted_scores, - sorted_indices, - valid_count, - max_output_boxes_per_class, - iou_threshold, - _nms_loop, - ) - row_offsets = cumsum(num_detections, exclusive=True, dtype="int64") - num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1) - selected_indices = collect_selected_indices( - num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir - ) - return [selected_indices, num_total_detections] - selected_indices, selected_scores, num_detections = run_all_class_nms( boxes, sorted_scores, @@ -854,8 +837,18 @@ def all_class_non_max_suppression( max_output_boxes_per_class, iou_threshold, _nms_loop, - return_scores=True, + return_scores=(output_format == "tensorflow"), ) + + if output_format == "onnx": + row_offsets = cumsum(num_detections, exclusive=True, dtype="int64") + num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1) + + selected_indices = collect_selected_indices( + num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + ) + return [selected_indices, num_total_detections] + num_detections_per_batch = reshape(num_detections, (batch, num_class)) row_offsets = cumsum(num_detections_per_batch, exclusive=True, dtype="int64", axis=1) num_total_detections = reduction.sum(cast(num_detections_per_batch, "int64"), axis=1) diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py index 4422aee1fb80..654e2d078c42 100644 --- a/python/tvm/topi/vision/nms_util.py +++ b/python/tvm/topi/vision/nms_util.py @@ -306,7 +306,7 @@ def run_all_class_nms( ) if return_scores is False: - return te.extern( + selected_indices, num_detections = te.extern( [(batch_class, num_boxes), (1, batch_class)], [boxes, sorted_scores, sorted_indices, valid_count], lambda ins, outs: _all_class_nms_ir( @@ -334,6 +334,7 @@ def run_all_class_nms( name="all_class_nms", tag="all_class_nms", ) + return selected_indices, None, num_detections return te.extern( [(batch_class, num_boxes), (batch_class, num_boxes), (1, batch_class)],