Skip to content

Commit

Permalink
collect indices and scores in one kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 30, 2021
1 parent 2b441c3 commit 480f6b7
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 33 deletions.
29 changes: 16 additions & 13 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,7 @@ def _collect_selected_indices_and_scores_ir(
selected_scores,
num_detections,
row_offsets,
num_total_detections,
collected_indices,
collected_scores,
):
Expand All @@ -1007,6 +1008,7 @@ def _collect_selected_indices_and_scores_ir(
selected_scores = ib.buffer_ptr(selected_scores)
num_detections = ib.buffer_ptr(num_detections)
row_offsets = ib.buffer_ptr(row_offsets)
num_total_detections = ib.buffer_ptr(num_total_detections)
collected_indices = ib.buffer_ptr(collected_indices)
collected_scores = ib.buffer_ptr(collected_scores)

Expand All @@ -1028,23 +1030,23 @@ def _collect_selected_indices_and_scores_ir(
batch_id = idy // num_class
class_id = idy % num_class

with ib.if_scope(idx < num_boxes):
offset = idx + class_id * num_boxes
collected_indices[batch_id, offset, 0] = zero
collected_indices[batch_id, offset, 1] = zero
collected_scores[batch_id, offset] = -1.0

with ib.new_scope():
idx = bx * nthread_tx + tx
idy = cast(by, "int64")
batch_id = idy // num_class
class_id = idy % num_class
offset = row_offsets[batch_id, class_id] + idx

with ib.if_scope(idx < num_detections[batch_id, class_id]):
offset = row_offsets[batch_id, class_id] + idx
collected_indices[batch_id, offset, 0] = class_id
collected_indices[batch_id, offset, 1] = cast(selected_indices[idy, idx], "int64")
collected_scores[batch_id, offset] = selected_scores[idy, idx]
with ib.else_scope():
with ib.if_scope(idx < num_boxes):
offset = (
num_total_detections[batch_id]
+ class_id * num_boxes
- row_offsets[batch_id, class_id]
+ idx
- num_detections[batch_id, class_id]
)
collected_indices[batch_id, offset, 0] = zero
collected_indices[batch_id, offset, 1] = zero
collected_scores[batch_id, offset] = -1.0

return ib.get()

Expand Down Expand Up @@ -1136,6 +1138,7 @@ def all_class_non_max_suppression(
selected_scores,
num_detections_per_batch,
row_offsets,
num_total_detections,
_collect_selected_indices_and_scores_ir,
)
topk_indices = topk(selected_scores, k=max_total_size, axis=1, ret_type="indices")[0]
Expand Down
34 changes: 18 additions & 16 deletions python/tvm/topi/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,7 @@ def _collect_selected_indices_and_scores_ir(
selected_scores,
num_detections,
row_offsets,
num_total_detections,
collected_indices,
collected_scores,
):
Expand All @@ -746,6 +747,7 @@ def _collect_selected_indices_and_scores_ir(
selected_scores = ib.buffer_ptr(selected_scores)
num_detections = ib.buffer_ptr(num_detections)
row_offsets = ib.buffer_ptr(row_offsets)
num_total_detections = ib.buffer_ptr(num_total_detections)
collected_indices = ib.buffer_ptr(collected_indices)
collected_scores = ib.buffer_ptr(collected_scores)
zero = cast(0, "int64")
Expand All @@ -756,23 +758,22 @@ def _collect_selected_indices_and_scores_ir(
class_id = i % num_class

with ib.for_range(0, num_boxes, name="j") as j:
offset = j + class_id * num_boxes
collected_indices[batch_id, offset, 0] = zero
collected_indices[batch_id, offset, 1] = zero
collected_scores[batch_id, offset] = -1.0

with ib.for_range(0, batch_size * num_class, name="i", kind="parallel") as i:
i = cast(i, "int64")
batch_id = i // num_class
class_id = i % num_class

with ib.for_range(0, num_boxes, name="j") as j:
offset = row_offsets[batch_id, class_id]

with ib.if_scope(j < num_detections[batch_id, class_id]):
collected_indices[batch_id, offset + j, 0] = class_id
collected_indices[batch_id, offset + j, 1] = cast(selected_indices[i, j], "int64")
collected_scores[batch_id, offset + j] = selected_scores[i, j]
offset = row_offsets[batch_id, class_id] + j
collected_indices[batch_id, offset, 0] = class_id
collected_indices[batch_id, offset, 1] = cast(selected_indices[i, j], "int64")
collected_scores[batch_id, offset] = selected_scores[i, j]
with ib.else_scope():
offset = (
num_total_detections[batch_id]
+ class_id * num_boxes
- row_offsets[batch_id, class_id]
+ j
- num_detections[batch_id, class_id]
)
collected_indices[batch_id, offset, 0] = zero
collected_indices[batch_id, offset, 1] = zero
collected_scores[batch_id, offset] = -1.0

return ib.get()

Expand Down Expand Up @@ -862,6 +863,7 @@ def all_class_non_max_suppression(
selected_scores,
num_detections_per_batch,
row_offsets,
num_total_detections,
_collect_selected_indices_and_scores_ir,
)
topk_indices = topk(selected_scores, k=max_total_size, axis=1, ret_type="indices")
Expand Down
20 changes: 16 additions & 4 deletions python/tvm/topi/vision/nms_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def collect_selected_indices(num_class, selected_indices, num_detections, row_of


def collect_selected_indices_and_scores(
selected_indices, selected_scores, num_detections, row_offsets, ir
selected_indices, selected_scores, num_detections, row_offsets, num_total_detections, ir
):
batch_size, num_class = row_offsets.shape
num_boxes = selected_indices.shape[1]
Expand All @@ -148,13 +148,25 @@ def collect_selected_indices_and_scores(
row_offsets_buf = tvm.tir.decl_buffer(
row_offsets.shape, row_offsets.dtype, "row_offsets_buf", data_alignment=8
)
num_total_detections_buf = tvm.tir.decl_buffer(
num_total_detections.shape,
num_total_detections.dtype,
"num_total_detections_buf",
data_alignment=8,
)

return te.extern(
[(batch_size, num_class * num_boxes, 2), (batch_size, num_class * num_boxes)],
[selected_indices, selected_scores, num_detections, row_offsets],
lambda ins, outs: ir(ins[0], ins[1], ins[2], ins[3], outs[0], outs[1]),
[selected_indices, selected_scores, num_detections, row_offsets, num_total_detections],
lambda ins, outs: ir(ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], outs[1]),
dtype=["int64", "float32"],
in_buffers=[selected_indices_buf, selected_scores_buf, num_detections_buf, row_offsets_buf],
in_buffers=[
selected_indices_buf,
selected_scores_buf,
num_detections_buf,
row_offsets_buf,
num_total_detections_buf,
],
name="collect_indices",
tag="collect_indices",
)
Expand Down

0 comments on commit 480f6b7

Please sign in to comment.