diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index a5508da65d1d..0112b9a81716 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -1124,6 +1124,7 @@ def all_class_non_max_suppression( row_offsets, num_total_detections = exclusive_scan( num_detections_per_batch, return_reduction=True, output_dtype="int64", axis=1 ) + selected_indices, selected_scores = collect_selected_indices_and_scores( selected_indices, selected_scores, @@ -1132,7 +1133,9 @@ def all_class_non_max_suppression( num_total_detections, _collect_selected_indices_and_scores_ir, ) + topk_indices = topk(selected_scores, k=max_total_size, axis=1, ret_type="indices")[0] + return post_process_max_detections( selected_indices, topk_indices, num_total_detections, max_total_size ) diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index b85c1e1c44c3..1fe03fc2a628 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -852,6 +852,7 @@ def all_class_non_max_suppression( num_detections_per_batch = reshape(num_detections, (batch, num_class)) row_offsets = cumsum(num_detections_per_batch, exclusive=True, dtype="int64", axis=1) num_total_detections = reduction.sum(cast(num_detections_per_batch, "int64"), axis=1) + selected_indices, selected_scores = collect_selected_indices_and_scores( selected_indices, selected_scores, @@ -860,7 +861,9 @@ def all_class_non_max_suppression( num_total_detections, _collect_selected_indices_and_scores_ir, ) + topk_indices = topk(selected_scores, k=max_total_size, axis=1, ret_type="indices") + return post_process_max_detections( selected_indices, topk_indices,