Skip to content

Commit

Permalink
add cpu impl
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 30, 2021
1 parent 787d839 commit 025010e
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 21 deletions.
6 changes: 3 additions & 3 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,6 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro


def _collect_selected_indices_and_scores_ir(
num_class,
selected_indices,
selected_scores,
num_detections,
Expand Down Expand Up @@ -1049,7 +1048,7 @@ def all_class_non_max_suppression(
max_output_boxes_per_class,
iou_threshold,
score_threshold,
max_total_size,
max_total_size=None,
output_format="onnx",
):
"""Non-maximum suppression operator for object detection, corresponding to ONNX
Expand Down Expand Up @@ -1132,6 +1131,7 @@ def all_class_non_max_suppression(
row_offsets,
_collect_selected_indices_and_scores_ir,
)
topk_indices = topk(selected_scores, k=max_total_size, axis=1, ret_type="indices")[0]
return post_process_max_detections(
selected_indices, selected_scores, num_total_detections, max_total_size, topk
selected_indices, topk_indices, num_total_detections, max_total_size
)
103 changes: 88 additions & 15 deletions python/tvm/topi/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,17 @@
from tvm.te import hybrid
from tvm.tir import if_then_else

from ..sort import sort, argsort
from ..sort import sort, argsort, topk
from ..math import cast
from ..transform import reshape
from .. import reduction
from ..scan import cumsum
from .nms_util import (
binary_search,
collect_selected_indices,
collect_selected_indices_and_scores,
run_all_class_nms,
post_process_max_detections,
)


Expand Down Expand Up @@ -727,8 +729,55 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro
return ib.get()


def _collect_selected_indices_and_scores_ir(
selected_indices,
selected_scores,
num_detections,
row_offsets,
collected_indices,
collected_scores,
):
batch_size, num_class = row_offsets.shape
num_boxes = selected_indices.shape[1]

ib = tvm.tir.ir_builder.create()

selected_indices = ib.buffer_ptr(selected_indices)
selected_scores = ib.buffer_ptr(selected_scores)
num_detections = ib.buffer_ptr(num_detections)
row_offsets = ib.buffer_ptr(row_offsets)
collected_indices = ib.buffer_ptr(collected_indices)
collected_scores = ib.buffer_ptr(collected_scores)
zero = cast(0, "int64")

with ib.for_range(0, batch_size * num_class, name="i", kind="parallel") as i:
i = cast(i, "int64")
batch_id = i // num_class
class_id = i % num_class

with ib.for_range(0, num_boxes, name="j") as j:
offset = row_offsets[batch_id, class_id] + j

with ib.if_scope(j < num_detections[batch_id, class_id]):
collected_indices[batch_id, offset, 0] = class_id
collected_indices[batch_id, offset, 1] = cast(selected_indices[i, j], "int64")
collected_scores[batch_id, offset] = selected_scores[i, j]
with ib.else_scope():
collected_indices[batch_id, offset, 0] = zero
collected_indices[batch_id, offset, 1] = zero
collected_scores[batch_id, offset] = -1.0

return ib.get()


def all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format
boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
max_total_size=None,
output_format="onnx",
):
"""Non-maximum suppression operator for object detection, corresponding to ONNX
NonMaxSuppression and TensorFlow combined_non_max_suppression.
Expand All @@ -750,7 +799,7 @@ def all_class_non_max_suppression(
score_threshold : float or tvm.te.Tensor, optional
Score threshold to filter out low score boxes early
output_format : TODO
Returns
Expand All @@ -771,23 +820,47 @@ def all_class_non_max_suppression(
sorted_indices = argsort(scores, axis=1, is_ascend=False, dtype="int32")
valid_count = _get_valid_box_count(sorted_scores, score_threshold)

selected_indices, num_detections = run_all_class_nms(
if output_format == "onnx":
selected_indices, num_detections = run_all_class_nms(
boxes,
sorted_scores,
sorted_indices,
valid_count,
max_output_boxes_per_class,
iou_threshold,
_nms_loop,
)
row_offsets = cumsum(num_detections, exclusive=True, dtype="int64")
num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1)
selected_indices = collect_selected_indices(
num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir
)
return [selected_indices, num_total_detections]

selected_indices, selected_scores, num_detections = run_all_class_nms(
boxes,
sorted_scores,
sorted_indices,
valid_count,
max_output_boxes_per_class,
iou_threshold,
_nms_loop,
return_scores=True,
)
num_detections_per_batch = reshape(num_detections, (batch, num_class))
row_offsets = cumsum(num_detections_per_batch, exclusive=True, dtype="int64", axis=1)
num_total_detections = reduction.sum(cast(num_detections_per_batch, "int64"), axis=1)
selected_indices, selected_scores = collect_selected_indices_and_scores(
selected_indices,
selected_scores,
num_detections_per_batch,
row_offsets,
_collect_selected_indices_and_scores_ir,
)
topk_indices = topk(selected_scores, k=max_total_size, axis=1, ret_type="indices")
return post_process_max_detections(
selected_indices,
topk_indices,
num_total_detections,
max_total_size,
)

row_offsets = cumsum(num_detections, exclusive=True, dtype="int64")

num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1)

if output_format == "onnx":
selected_indices = collect_selected_indices(
num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir
)

return [selected_indices, num_total_detections]
5 changes: 2 additions & 3 deletions python/tvm/topi/vision/nms_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def collect_selected_indices_and_scores(
return te.extern(
[(batch_size, num_class * num_boxes, 2), (batch_size, num_class * num_boxes)],
[selected_indices, selected_scores, num_detections, row_offsets],
lambda ins, outs: ir(num_class, ins[0], ins[1], ins[2], ins[3], outs[0], outs[1]),
lambda ins, outs: ir(ins[0], ins[1], ins[2], ins[3], outs[0], outs[1]),
dtype=["int64", "float32"],
in_buffers=[selected_indices_buf, selected_scores_buf, num_detections_buf, row_offsets_buf],
name="collect_indices",
Expand Down Expand Up @@ -354,9 +354,8 @@ def run_all_class_nms(


def post_process_max_detections(
selected_indices, selected_scores, num_total_detections, max_total_size, topk_func
selected_indices, topk_indices, num_total_detections, max_total_size
):
topk_indices = topk_func(selected_scores, k=max_total_size, axis=1, ret_type="indices")[0]
topk_indices = expand_dims(topk_indices, axis=0)
final_indices = gather_nd(selected_indices, topk_indices, batch_dims=1)
num_detections = minimum(num_total_detections, max_total_size)
Expand Down

0 comments on commit 025010e

Please sign in to comment.