Skip to content

Commit

Permalink
begin supporting per batch output
Browse files: browse the repository at this point in the history.
  • Loading branch information
masahi committed May 29, 2021
1 parent 0044365 commit 5f349f7
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro


def all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, max_detection_per_batch=-1, output_format="onnx"
):
"""Non-maximum suppression operator for object detection, corresponding to ONNX
NonMaxSuppression and TensorFlow combined_non_max_suppression.
Expand All @@ -1011,7 +1011,7 @@ def all_class_non_max_suppression(
score_threshold : float or tvm.te.Tensor, optional
Score threshold to filter out low score boxes early
output_format : str
Returns
Expand All @@ -1027,6 +1027,9 @@ def all_class_non_max_suppression(
"""
batch, num_class, num_boxes = scores.shape

if max_detection_per_batch == -1:
max_detection_per_batch = num_class * num_boxes

scores = reshape(scores, (batch * num_class, num_boxes))
sorted_scores, sorted_indices = _dispatch_sort(scores, ret_type="both")
valid_count = _get_valid_box_count(sorted_scores, score_threshold)
Expand All @@ -1041,15 +1044,17 @@ def all_class_non_max_suppression(
_nms_loop,
)

row_offsets, num_total_detections = exclusive_scan(
num_detections, return_reduction=True, output_dtype="int64"
)

if output_format == "onnx":
row_offsets, num_total_detections = exclusive_scan(
num_detections, return_reduction=True, output_dtype="int64"
)
selected_indices = collect_selected_indices(
num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir
)
else:
selected_indices = reshape(selected_indices, (batch, num_class, num_boxes))
return [selected_indices, num_total_detections]

return [selected_indices, num_total_detections]
# tf mode, return (batch_size, max_total_size, 2)
num_detections_per_batch = reshape(num_detections, (batch, num_class))
row_offsets, num_total_detections = exclusive_scan(
num_detections_per_batch, return_reduction=True, output_dtype="int64", axis=1
)

0 comments on commit 5f349f7

Please sign in to comment.