From afad2a2e920c98d269c6000035f31392cff7b6a3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 29 May 2021 16:29:53 +0900 Subject: [PATCH] collect indices and scores in one kernel --- python/tvm/topi/cuda/nms.py | 29 +++++++++++++------------ python/tvm/topi/vision/nms.py | 34 ++++++++++++++++-------------- python/tvm/topi/vision/nms_util.py | 20 ++++++++++++++---- 3 files changed, 50 insertions(+), 33 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 8783541819d3..360d568030be 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -995,6 +995,7 @@ def _collect_selected_indices_and_scores_ir( selected_scores, num_detections, row_offsets, + num_total_detections, collected_indices, collected_scores, ): @@ -1007,6 +1008,7 @@ def _collect_selected_indices_and_scores_ir( selected_scores = ib.buffer_ptr(selected_scores) num_detections = ib.buffer_ptr(num_detections) row_offsets = ib.buffer_ptr(row_offsets) + num_total_detections = ib.buffer_ptr(num_total_detections) collected_indices = ib.buffer_ptr(collected_indices) collected_scores = ib.buffer_ptr(collected_scores) @@ -1028,23 +1030,23 @@ def _collect_selected_indices_and_scores_ir( batch_id = idy // num_class class_id = idy % num_class - with ib.if_scope(idx < num_boxes): - offset = idx + class_id * num_boxes - collected_indices[batch_id, offset, 0] = zero - collected_indices[batch_id, offset, 1] = zero - collected_scores[batch_id, offset] = -1.0 - - with ib.new_scope(): - idx = bx * nthread_tx + tx - idy = cast(by, "int64") - batch_id = idy // num_class - class_id = idy % num_class - offset = row_offsets[batch_id, class_id] + idx - with ib.if_scope(idx < num_detections[batch_id, class_id]): + offset = row_offsets[batch_id, class_id] + idx collected_indices[batch_id, offset, 0] = class_id collected_indices[batch_id, offset, 1] = cast(selected_indices[idy, idx], "int64") collected_scores[batch_id, offset] = selected_scores[idy, idx] + with ib.else_scope(): + with ib.if_scope(idx < num_boxes): + offset = ( + num_total_detections[batch_id] + + class_id * num_boxes + - row_offsets[batch_id, class_id] + + idx + - num_detections[batch_id, class_id] + ) + collected_indices[batch_id, offset, 0] = zero + collected_indices[batch_id, offset, 1] = zero + collected_scores[batch_id, offset] = -1.0 return ib.get() @@ -1136,6 +1138,7 @@ def all_class_non_max_suppression( selected_scores, num_detections_per_batch, row_offsets, + num_total_detections, _collect_selected_indices_and_scores_ir, ) topk_indices = topk(selected_scores, k=max_total_size, axis=1, ret_type="indices")[0] diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 1ea8247f5533..f7e279e17dd0 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -734,6 +734,7 @@ def _collect_selected_indices_and_scores_ir( selected_scores, num_detections, row_offsets, + num_total_detections, collected_indices, collected_scores, ): @@ -746,6 +747,7 @@ def _collect_selected_indices_and_scores_ir( selected_scores = ib.buffer_ptr(selected_scores) num_detections = ib.buffer_ptr(num_detections) row_offsets = ib.buffer_ptr(row_offsets) + num_total_detections = ib.buffer_ptr(num_total_detections) collected_indices = ib.buffer_ptr(collected_indices) collected_scores = ib.buffer_ptr(collected_scores) zero = cast(0, "int64") @@ -756,23 +758,22 @@ def _collect_selected_indices_and_scores_ir( class_id = i % num_class with ib.for_range(0, num_boxes, name="j") as j: - offset = j + class_id * num_boxes - collected_indices[batch_id, offset, 0] = zero - collected_indices[batch_id, offset, 1] = zero - collected_scores[batch_id, offset] = -1.0 - - with ib.for_range(0, batch_size * num_class, name="i", kind="parallel") as i: - i = cast(i, "int64") - batch_id = i // num_class - class_id = i % num_class - - with ib.for_range(0, num_boxes, name="j") as j: - offset = row_offsets[batch_id, class_id] - with ib.if_scope(j < num_detections[batch_id, class_id]): - collected_indices[batch_id, offset + j, 0] = class_id - collected_indices[batch_id, offset + j, 1] = cast(selected_indices[i, j], "int64") - collected_scores[batch_id, offset + j] = selected_scores[i, j] + offset = row_offsets[batch_id, class_id] + j + collected_indices[batch_id, offset, 0] = class_id + collected_indices[batch_id, offset, 1] = cast(selected_indices[i, j], "int64") + collected_scores[batch_id, offset] = selected_scores[i, j] + with ib.else_scope(): + offset = ( + num_total_detections[batch_id] + + class_id * num_boxes + - row_offsets[batch_id, class_id] + + j + - num_detections[batch_id, class_id] + ) + collected_indices[batch_id, offset, 0] = zero + collected_indices[batch_id, offset, 1] = zero + collected_scores[batch_id, offset] = -1.0 return ib.get() @@ -862,6 +863,7 @@ def all_class_non_max_suppression( selected_scores, num_detections_per_batch, row_offsets, + num_total_detections, _collect_selected_indices_and_scores_ir, ) topk_indices = topk(selected_scores, k=max_total_size, axis=1, ret_type="indices") diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py index 72b791c41e04..4422aee1fb80 100644 --- a/python/tvm/topi/vision/nms_util.py +++ b/python/tvm/topi/vision/nms_util.py @@ -131,7 +131,7 @@ def collect_selected_indices(num_class, selected_indices, num_detections, row_of def collect_selected_indices_and_scores( - selected_indices, selected_scores, num_detections, row_offsets, ir + selected_indices, selected_scores, num_detections, row_offsets, num_total_detections, ir ): batch_size, num_class = row_offsets.shape num_boxes = selected_indices.shape[1] @@ -148,13 +148,25 @@ def collect_selected_indices_and_scores( row_offsets_buf = tvm.tir.decl_buffer( row_offsets.shape, row_offsets.dtype, "row_offsets_buf", data_alignment=8 ) + num_total_detections_buf = tvm.tir.decl_buffer( + num_total_detections.shape, + num_total_detections.dtype, + "num_total_detections_buf", + data_alignment=8, + ) return te.extern( [(batch_size, num_class * num_boxes, 2), (batch_size, num_class * num_boxes)], - [selected_indices, selected_scores, num_detections, row_offsets], - lambda ins, outs: ir(ins[0], ins[1], ins[2], ins[3], outs[0], outs[1]), + [selected_indices, selected_scores, num_detections, row_offsets, num_total_detections], + lambda ins, outs: ir(ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], outs[1]), dtype=["int64", "float32"], - in_buffers=[selected_indices_buf, selected_scores_buf, num_detections_buf, row_offsets_buf], + in_buffers=[ + selected_indices_buf, + selected_scores_buf, + num_detections_buf, + row_offsets_buf, + num_total_detections_buf, + ], name="collect_indices", tag="collect_indices", )