From d43e801289621e71a79c71308dedeef0969264be Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 29 May 2021 15:23:09 +0900 Subject: [PATCH] cpu nms bug fixed --- python/tvm/topi/vision/nms.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 4b7386f0fb38..1ea8247f5533 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -24,7 +24,7 @@ from ..sort import sort, argsort, topk from ..math import cast -from ..transform import reshape +from ..transform import reshape, arange, expand_dims from .. import reduction from ..scan import cumsum from .nms_util import ( @@ -756,16 +756,23 @@ def _collect_selected_indices_and_scores_ir( class_id = i % num_class with ib.for_range(0, num_boxes, name="j") as j: - offset = row_offsets[batch_id, class_id] + j + offset = j + class_id * num_boxes + collected_indices[batch_id, offset, 0] = zero + collected_indices[batch_id, offset, 1] = zero + collected_scores[batch_id, offset] = -1.0 + + with ib.for_range(0, batch_size * num_class, name="i", kind="parallel") as i: + i = cast(i, "int64") + batch_id = i // num_class + class_id = i % num_class + + with ib.for_range(0, num_boxes, name="j") as j: + offset = row_offsets[batch_id, class_id] with ib.if_scope(j < num_detections[batch_id, class_id]): - collected_indices[batch_id, offset, 0] = class_id - collected_indices[batch_id, offset, 1] = cast(selected_indices[i, j], "int64") - collected_scores[batch_id, offset] = selected_scores[i, j] - with ib.else_scope(): - collected_indices[batch_id, offset, 0] = zero - collected_indices[batch_id, offset, 1] = zero - collected_scores[batch_id, offset] = -1.0 + collected_indices[batch_id, offset + j, 0] = class_id + collected_indices[batch_id, offset + j, 1] = cast(selected_indices[i, j], "int64") + collected_scores[batch_id, offset + j] = selected_scores[i, j] return ib.get()