Skip to content

Commit

Permalink
refactoring
Browse files: browse the repository at this point in the history
  • Loading branch information
masahi committed May 29, 2021
1 parent 7b87922 commit 6c7aaeb
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 48 deletions.
59 changes: 13 additions & 46 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@
from .scan import exclusive_scan
from ..utils import ceil_div
from ..math import cast
from .. import reduction
from ..broadcast import minimum
from ..transform import reshape, strided_slice, gather_nd, expand_dims, squeeze
from ..transform import reshape
from ..vision.nms_util import (
calculate_overlap,
binary_search,
collect_selected_indices,
collect_selected_indices_and_scores,
run_all_class_nms,
post_process_max_detections,
)


Expand Down Expand Up @@ -990,7 +990,7 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro
return ib.get()


def _collect_selected_indices_tf_ir(
def _collect_selected_indices_and_scores_ir(
num_class,
selected_indices,
selected_scores,
Expand Down Expand Up @@ -1043,36 +1043,6 @@ def _collect_selected_indices_tf_ir(
return ib.get()


def collect_selected_indices_tf(selected_indices, selected_scores, num_detections, row_offsets):
    """Build the extern op that gathers selected indices and scores into two outputs.

    The actual kernel is produced by ``_collect_selected_indices_tf_ir``; this
    function only declares the input buffers and the output shapes/dtypes.

    Parameters
    ----------
    selected_indices : tvm.te.Tensor
        Indices chosen by the NMS loop. Only ``shape[1]`` is read here (used as
        the number of boxes).

    selected_scores : tvm.te.Tensor
        Scores that accompany ``selected_indices``.

    num_detections : tvm.te.Tensor
        Detection counts, forwarded unchanged to the IR builder.

    row_offsets : tvm.te.Tensor
        2-D tensor whose shape is unpacked as ``(batch_size, num_class)``.

    Returns
    -------
    outs : list of tvm.te.Tensor
        Two tensors: an int64 one of shape
        ``(batch_size, num_class * num_boxes, 2)`` and a float32 one of shape
        ``(batch_size, num_class * num_boxes)``.
    """
    batch_size, num_class = row_offsets.shape
    flat_len = num_class * selected_indices.shape[1]

    def _declare(tensor, buf_name):
        # Every input uses the same 8-byte data alignment.
        return tvm.tir.decl_buffer(tensor.shape, tensor.dtype, buf_name, data_alignment=8)

    inputs = [selected_indices, selected_scores, num_detections, row_offsets]
    buffer_names = [
        "selected_indices_buf",
        "selected_scores_buf",
        "num_detections_buf",
        "row_offsets_buf",
    ]
    input_buffers = [_declare(tensor, buf_name) for tensor, buf_name in zip(inputs, buffer_names)]

    def _build_ir(ins, outs):
        return _collect_selected_indices_tf_ir(
            num_class, ins[0], ins[1], ins[2], ins[3], outs[0], outs[1]
        )

    output_shapes = [(batch_size, flat_len, 2), (batch_size, flat_len)]
    return te.extern(
        output_shapes,
        inputs,
        _build_ir,
        dtype=["int64", "float32"],
        in_buffers=input_buffers,
        name="collect_indices",
        tag="collect_indices",
    )


def all_class_non_max_suppression(
boxes,
scores,
Expand Down Expand Up @@ -1141,8 +1111,6 @@ def all_class_non_max_suppression(
)
return [selected_indices, num_total_detections]

max_detection_per_batch = max_total_size

selected_indices, selected_scores, num_detections = run_all_class_nms(
boxes,
sorted_scores,
Expand All @@ -1153,18 +1121,17 @@ def all_class_non_max_suppression(
_nms_loop,
return_scores=True,
)

# tf mode, return (batch_size, max_total_size, 2)
num_detections_per_batch = reshape(num_detections, (batch, num_class))
row_offsets, num_total_detections = exclusive_scan(
num_detections_per_batch, return_reduction=True, output_dtype="int64", axis=1
)
selected_indices, selected_scores = collect_selected_indices_tf(
selected_indices, selected_scores, num_detections_per_batch, row_offsets
selected_indices, selected_scores = collect_selected_indices_and_scores(
selected_indices,
selected_scores,
num_detections_per_batch,
row_offsets,
_collect_selected_indices_and_scores_ir,
)
return post_process_max_detections(
selected_indices, selected_scores, num_total_detections, max_total_size, topk
)
topk_indices = topk(selected_scores, k=max_detection_per_batch, axis=1, ret_type="indices")[0]
topk_indices = expand_dims(topk_indices, axis=0)
final_indices = gather_nd(selected_indices, topk_indices, batch_dims=1)
num_detections = minimum(num_total_detections, max_detection_per_batch)

return [final_indices, num_detections]
46 changes: 44 additions & 2 deletions python/tvm/topi/vision/nms_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
"""Common utilities used in Non-maximum suppression operators"""
import tvm
from tvm import te
from ..broadcast import minimum
from ..transform import gather_nd, expand_dims


def _get_boundaries(output, box_idx):
Expand Down Expand Up @@ -128,6 +130,36 @@ def collect_selected_indices(num_class, selected_indices, num_detections, row_of
)


def collect_selected_indices_and_scores(
    selected_indices, selected_scores, num_detections, row_offsets, ir
):
    """Wrap a target-specific IR builder in an extern op yielding indices and scores.

    Parameters
    ----------
    selected_indices : tvm.te.Tensor
        Indices chosen by the NMS loop. Only ``shape[1]`` is read here (used as
        the number of boxes).

    selected_scores : tvm.te.Tensor
        Scores that accompany ``selected_indices``.

    num_detections : tvm.te.Tensor
        Detection counts, forwarded unchanged to ``ir``.

    row_offsets : tvm.te.Tensor
        2-D tensor whose shape is unpacked as ``(batch_size, num_class)``.

    ir : function
        IR builder invoked as
        ``ir(num_class, selected_indices, selected_scores, num_detections,
        row_offsets, out_indices, out_scores)`` with buffer arguments.

    Returns
    -------
    outs : list of tvm.te.Tensor
        Two tensors: an int64 one of shape
        ``(batch_size, num_class * num_boxes, 2)`` and a float32 one of shape
        ``(batch_size, num_class * num_boxes)``.
    """
    batch_size, num_class = row_offsets.shape
    num_boxes = selected_indices.shape[1]

    # (tensor, buffer name) pairs, kept in the order the IR builder expects.
    named_inputs = (
        (selected_indices, "selected_indices_buf"),
        (selected_scores, "selected_scores_buf"),
        (num_detections, "num_detections_buf"),
        (row_offsets, "row_offsets_buf"),
    )
    input_buffers = []
    for tensor, buf_name in named_inputs:
        input_buffers.append(
            tvm.tir.decl_buffer(tensor.shape, tensor.dtype, buf_name, data_alignment=8)
        )

    indices_shape = (batch_size, num_class * num_boxes, 2)
    scores_shape = (batch_size, num_class * num_boxes)

    def _invoke_ir(ins, outs):
        return ir(num_class, ins[0], ins[1], ins[2], ins[3], outs[0], outs[1])

    return te.extern(
        [indices_shape, scores_shape],
        [tensor for tensor, _ in named_inputs],
        _invoke_ir,
        dtype=["int64", "float32"],
        in_buffers=input_buffers,
        name="collect_indices",
        tag="collect_indices",
    )


def _all_class_nms_ir(
boxes,
sorted_scores,
Expand Down Expand Up @@ -208,7 +240,7 @@ def run_all_class_nms(
max_output_size_per_class,
iou_threshold,
nms_loop,
return_scores=False
return_scores=False,
):
"""The core all class NMS routine
Expand Down Expand Up @@ -276,7 +308,7 @@ def run_all_class_nms(
iou_threshold,
max_output_size_per_class,
outs[0], # box_indices
None, # scores
None, # scores
outs[1], # num_selected_boxes
nms_loop,
),
Expand Down Expand Up @@ -319,3 +351,13 @@ def run_all_class_nms(
name="all_class_nms",
tag="all_class_nms",
)


def post_process_max_detections(
    selected_indices, selected_scores, num_total_detections, max_total_size, topk_func
):
    """Keep the ``max_total_size`` best-scoring entries of ``selected_indices``.

    Parameters
    ----------
    selected_indices : tvm.te.Tensor
        Candidate indices to gather from.

    selected_scores : tvm.te.Tensor
        Scores ranked along axis 1 to decide which candidates are kept.

    num_total_detections : tvm.te.Tensor
        Detection counts; clamped from above by ``max_total_size``.

    max_total_size : int or tvm.te.Tensor
        Passed as ``k`` to ``topk_func`` and as the upper bound on the counts.

    topk_func : function
        Target-specific top-k, called with
        ``(selected_scores, k=max_total_size, axis=1, ret_type="indices")``.

    Returns
    -------
    out : list
        ``[final_indices, num_detections]``.
    """
    ranked = topk_func(selected_scores, k=max_total_size, axis=1, ret_type="indices")
    # NOTE(review): only element 0 of the top-k result is used and a leading
    # axis is then added back; confirm this is intended for batch sizes > 1.
    gather_index = expand_dims(ranked[0], axis=0)
    return [
        gather_nd(selected_indices, gather_index, batch_dims=1),
        minimum(num_total_detections, max_total_size),
    ]

0 comments on commit 6c7aaeb

Please sign in to comment.