fixed topk handling
masahi committed Dec 28, 2020
1 parent 1913f97 commit 37d7a19
Showing 1 changed file: python/tvm/topi/cuda/nms.py (63 additions, 10 deletions)
@@ -500,7 +500,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
j = bx * max_threads + tx
with ib.if_scope(j < nkeep):
src_idx = base_src_idx + sorted_index[i * num_anchors + j] * box_data_length
- with ib.for_range(0, 4) as k:
+ with ib.for_range(0, 4, for_type="unroll") as k:
out_bboxes[(base_bbox_idx + j * 4 + k)] = data[src_idx + coord_start + k]

out_scores[i * num_anchors + j] = data[src_idx + score_index]
@@ -513,7 +513,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
# Only needed for return_indices = False case
if return_indices is False:
with ib.if_scope(j < num_anchors):
- with ib.for_range(0, 4) as k:
+ with ib.for_range(0, 4, for_type="unroll") as k:
out_bboxes[(base_bbox_idx + j * 4 + k)] = -1.0

out_scores[i, j] = -1.0
@@ -529,7 +529,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
with ib.if_scope(j < valid_count[i]):
src_offset = base_src_idx + j * box_data_length

- with ib.for_range(0, 4) as k:
+ with ib.for_range(0, 4, for_type="unroll") as k:
out_bboxes[base_bbox_idx + j * 4 + k] = data[src_offset + coord_start + k]
out_scores[i * num_anchors + j] = data[src_offset + score_index]

@@ -581,7 +581,7 @@ def nms_inner_loop(ib, j):

with ib.if_scope(
tvm.tir.all(
- k < num_anchors,
+ k < valid_count[i],
out_scores[i, k] > 0, # is the box k still valid?
tvm.tir.any(
force_suppress > 0,
@@ -609,11 +609,15 @@ def nms_inner_loop(ib, j):

with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
# Apply nms
- with ib.for_range(0, valid_count[i]) as j:
+ nkeep = if_then_else(
+     tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]
+ )
+
+ with ib.for_range(0, nkeep) as j:
# Proceed to the inner loop if the box j is still valid
with ib.if_scope(out_scores[i, j] > -1.0):
with ib.if_scope(max_output_size > 0):
- # No need to do more iteration if we already reach max_output_size boxes
+ # No need to do more iteration if we have already reached max_output_size boxes
# TODO(masahi): Add TIR while loop to realize early exit from the outer loop
with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
nms_inner_loop(ib, j)
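
To make the effect of the new nkeep clamp above concrete, here is a minimal pure-Python sketch of the same rule; the standalone helper below is illustrative only and is not part of this file:

    def nkeep(valid_count, top_k):
        # Visit at most top_k boxes when a positive top_k is smaller than the
        # number of valid boxes; otherwise visit every valid box.
        if 0 < top_k < valid_count:
            return top_k
        return valid_count

    # nkeep(100, 10) -> 10, nkeep(5, 10) -> 5, nkeep(100, -1) -> 100
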
@@ -752,9 +756,56 @@ def _run_nms(
)


- def _concatenate_outputs(out_bboxes, out_scores, out_sorted_ids, score_index, id_index):
-     # TODO
-     return None
+ def _concatenate_outputs(
+     out_bboxes, out_scores, out_sorted_ids, out_shape, coord_start, score_index, id_index
+ ):
+     batch_size = out_bboxes.shape[0]
+     num_anchors = out_bboxes.shape[1]
+
+     def ir(out_bboxes, out_scores, out_sorted_ids, out):
+         ib = tvm.tir.ir_builder.create()
+
+         out_bboxes = ib.buffer_ptr(out_bboxes)
+         out_scores = ib.buffer_ptr(out_scores)
+         out_sorted_ids = ib.buffer_ptr(out_sorted_ids)
+         out = ib.buffer_ptr(out)
+
+         with ib.if_scope(num_anchors > 0):
+             max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+             nthread_tx = max_threads
+             nthread_bx = ceil_div(num_anchors, nthread_tx)
+             tx = te.thread_axis("threadIdx.x")
+             bx = te.thread_axis("blockIdx.x")
+             by = te.thread_axis("blockIdx.y")
+             ib.scope_attr(tx, "thread_extent", nthread_tx)
+             ib.scope_attr(bx, "thread_extent", nthread_bx)
+             ib.scope_attr(by, "thread_extent", batch_size)
+
+             tid = bx * nthread_tx + tx
+             i = by
+
+             with ib.if_scope(tid < num_anchors):
+                 with ib.for_range(0, 4, for_type="unroll") as j:
+                     out[i, tid, coord_start + j] = out_bboxes[i, tid, j]
+                 out[i, tid, score_index] = out_scores[i, tid]
+                 if id_index >= 0:
+                     # Store the class id of the kept box
+                     out[i, tid, id_index] = out_sorted_ids[i, tid]
+
+         return ib.get()
+
+     return te.extern(
+         [out_shape],
+         [out_bboxes, out_scores, out_sorted_ids],
+         lambda ins, outs: ir(
+             ins[0],
+             ins[1],
+             ins[2],
+             outs[0],  # concatenated output
+         ),
+         dtype=["float32"],
+         name="nms_output_concat",
+         tag="nms_output_concat",
+     )
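
For intuition, the kernel above behaves roughly like the following NumPy reference. This is a hedged sketch: the -1.0 fill for slots the kernel never writes, and the assumption that the four box coordinates occupy positions coord_start through coord_start + 3, are illustrative choices rather than guarantees of this commit:

    import numpy as np

    def concatenate_outputs_ref(out_bboxes, out_scores, out_sorted_ids,
                                out_shape, coord_start, score_index, id_index):
        # out_bboxes: (batch, num_anchors, 4)
        # out_scores, out_sorted_ids: (batch, num_anchors)
        out = np.full(out_shape, -1.0, dtype="float32")
        out[:, :, coord_start:coord_start + 4] = out_bboxes
        out[:, :, score_index] = out_scores
        if id_index >= 0:
            out[:, :, id_index] = out_sorted_ids
        return out
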


def non_max_suppression(
@@ -871,4 +922,6 @@ def non_max_suppression(
if return_indices:
return [box_indices, num_valid_boxes]

- return _concatenate_outputs(out_bboxes, out_scores, out_sorted_ids, score_index, id_index)
+ return _concatenate_outputs(
+     out_bboxes, out_scores, out_sorted_ids, data.shape, coord_start, score_index, id_index
+ )
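
For context, here is a hedged sketch of how a caller sees the two return modes. The placeholder shapes, keyword arguments, and explicit CUDA target scope are assumptions drawn from the surrounding file, not something this diff guarantees:

    import tvm
    from tvm import te, topi

    # Data layout assumed: (batch, num_anchors, [id, score, x1, y1, x2, y2])
    data = te.placeholder((1, 2500, 6), "float32", name="data")
    valid_count = te.placeholder((1,), "int32", name="valid_count")
    indices = te.placeholder((1, 2500), "int32", name="indices")

    with tvm.target.Target("cuda"):
        # return_indices=True would yield [box_indices, num_valid_boxes];
        # return_indices=False yields one (1, 2500, 6) tensor assembled by
        # _concatenate_outputs above.
        out = topi.cuda.non_max_suppression(
            data, valid_count, indices, top_k=10, return_indices=False
        )
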
