From 5f349f77c9c230ee636aceb52547502319c8ad77 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 28 May 2021 06:54:34 +0900 Subject: [PATCH] begin supporting per batch output --- python/tvm/topi/cuda/nms.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 991a572cec2e9..c3e1a1857baf7 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -989,7 +989,7 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro def all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, max_detection_per_batch=-1, output_format="onnx" ): """Non-maximum suppression operator for object detection, corresponding to ONNX NonMaxSuppression and TensorFlow combined_non_max_suppression. @@ -1011,7 +1011,7 @@ def all_class_non_max_suppression( score_threshold : float or tvm.te.Tensor, optional Score threshold to filter out low score boxes early - + output_format : str Returns @@ -1027,6 +1027,9 @@ def all_class_non_max_suppression( """ batch, num_class, num_boxes = scores.shape + if max_detection_per_batch == -1: + max_detection_per_batch = num_class * num_boxes + scores = reshape(scores, (batch * num_class, num_boxes)) sorted_scores, sorted_indices = _dispatch_sort(scores, ret_type="both") valid_count = _get_valid_box_count(sorted_scores, score_threshold) @@ -1041,15 +1044,17 @@ def all_class_non_max_suppression( _nms_loop, ) - row_offsets, num_total_detections = exclusive_scan( - num_detections, return_reduction=True, output_dtype="int64" - ) - if output_format == "onnx": + row_offsets, num_total_detections = exclusive_scan( + num_detections, return_reduction=True, output_dtype="int64" + ) selected_indices = collect_selected_indices( num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir ) - else: - selected_indices = reshape(selected_indices, (batch, num_class, num_boxes)) + return [selected_indices, num_total_detections] - return [selected_indices, num_total_detections] + # tf mode, return (batch_size, max_total_size, 2) + num_detections_per_batch = reshape(num_detections, (batch, num_class)) + row_offsets, num_total_detections = exclusive_scan( + num_detections_per_batch, return_reduction=True, output_dtype="int64", axis=1 + )