Skip to content

Commit

Permalink
finish concat output
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 28, 2020
1 parent 37d7a19 commit 68c6866
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,14 @@ def nms_ir(
dimension are like the output of arange(num_anchors) if get_valid_counts
is not used before non_max_suppression.
out : Buffer
Output buffer, to be filled with sorted boxes.
out_bboxes : Buffer
Output buffer, to be filled with sorted box coordinates.
out_scores : Buffer
Output buffer, to be filled with sorted scores.
out_class_ids : Buffer
Output buffer, to be filled with sorted class ids.
box_indices : Buffer
An index tensor mapping sorted indices to original indices
Expand Down Expand Up @@ -617,7 +623,8 @@ def nms_inner_loop(ib, j):
# Proceed to the inner loop if the box j is still valid
with ib.if_scope(out_scores[i, j] > -1.0):
with ib.if_scope(max_output_size > 0):
# No need to do more iteration if we have already reached max_output_size boxes
# No need to do more iteration if we have already reached max_output_size
# boxes
# TODO(masahi): Add TIR while loop to realize early exit from the outer loop
with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
nms_inner_loop(ib, j)
Expand Down Expand Up @@ -666,6 +673,7 @@ def _fetch_score_ir(data, score, axis):


def _get_sorted_indices(data, data_buf, score_index, score_shape):
"""Extract a 1D score tensor from the packed input and do argsort on it."""
score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8)
score_tensor = te.extern(
[score_shape],
Expand Down Expand Up @@ -707,6 +715,7 @@ def _run_nms(
score_index,
return_indices,
):
"""Run NMS using sorted scores."""
sort_tensor_buf = tvm.tir.decl_buffer(
sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8
)
Expand Down Expand Up @@ -757,17 +766,18 @@ def _run_nms(


def _concatenate_outputs(
out_bboxes, out_scores, out_sorted_ids, out_shape, coord_start, score_index, id_index
out_bboxes, out_scores, out_class_ids, out_shape, coord_start, score_index, id_index
):
"""Pack the results from NMS into a single 5D or 6D tensor."""
batch_size = out_bboxes.shape[0]
num_anchors = out_bboxes.shape[1]

def ir(out_bboxes, out_scores, out_sorted_ids, out):
def ir(out_bboxes, out_scores, out_class_ids, out):
ib = tvm.tir.ir_builder.create()

out_bboxes = ib.buffer_ptr(out_bboxes)
out_scores = ib.buffer_ptr(out_scores)
out_sorted_ids = ib.buffer_ptr(out_sorted_ids)
out_class_ids = ib.buffer_ptr(out_class_ids)
out = ib.buffer_ptr(out)

with ib.if_scope(num_anchors > 0):
Expand All @@ -789,19 +799,14 @@ def ir(out_bboxes, out_scores, out_sorted_ids, out):
out[i, tid, coord_start + j] = out_bboxes[i, tid, j]
out[i, tid, score_index] = out_scores[i, tid]
if id_index >= 0:
out[i, tid, score_index] = out_scores[i, tid]
out[i, tid, id_index] = out_class_ids[i, tid]

return ib.get()

return te.extern(
[out_shape],
[out_bboxes, out_scores, out_sorted_ids],
lambda ins, outs: ir(
ins[0],
ins[1],
ins[2],
outs[0], # sorted bbox
),
[out_bboxes, out_scores, out_class_ids],
lambda ins, outs: ir(ins[0], ins[1], ins[2], outs[0]),
dtype=["float32"],
name="nms_output_concat",
tag="nms_output_concat",
Expand Down Expand Up @@ -903,7 +908,7 @@ def non_max_suppression(

sort_tensor = _get_sorted_indices(data, data_buf, score_index, (data.shape[0], data.shape[1]))

out_bboxes, out_scores, out_sorted_ids, box_indices, num_valid_boxes = _run_nms(
out_bboxes, out_scores, out_class_ids, box_indices, num_valid_boxes = _run_nms(
data,
data_buf,
sort_tensor,
Expand All @@ -923,5 +928,5 @@ def non_max_suppression(
return [box_indices, num_valid_boxes]

return _concatenate_outputs(
out_bboxes, out_scores, out_sorted_ids, data.shape, coord_start, score_index, id_index
out_bboxes, out_scores, out_class_ids, data.shape, coord_start, score_index, id_index
)

0 comments on commit 68c6866

Please sign in to comment.