From 8c68a350cb8084fbeace55d9494b6f7b2cbd7391 Mon Sep 17 00:00:00 2001 From: masahi Date: Thu, 15 Apr 2021 04:18:53 +0900 Subject: [PATCH] [TOPI, Relay] A new NMS op variant for ONNX NMS / TF Combined NMS (#7796) * initial import * add c++ boilarplate * add python boilarpolate * update onnx frontend * fixing * update onnx frontend * fix shape * minor update * fix * fix shape func * fix for no box * more fix * made things 64 bit * int64 tweak * max_output_size doesn't need to be a callback * remove all_class_nms schedule * minor simplify * remove expand_dim * refactoring * simplify nms loop * cpu all_class_nms stub * updating ir for cpu * working with cpu * update cpu strategy, relay op also working * fix cpplint * fixing pylint * enable gpu test for onnx nms * tweak parallel * pyformat and lint * fix relay nms test * doc update for cpp relay * updating tests * updated tests * fix converting score_threshold to Expr * update doc * doc fix Co-authored-by: Masahiro Masuda --- include/tvm/relay/attrs/vision.h | 11 +- python/tvm/relay/frontend/onnx.py | 238 +---------- python/tvm/relay/op/op_attrs.py | 5 + python/tvm/relay/op/strategy/cuda.py | 12 + python/tvm/relay/op/strategy/generic.py | 28 +- python/tvm/relay/op/vision/_vision.py | 19 + python/tvm/relay/op/vision/nms.py | 48 +++ python/tvm/topi/cuda/__init__.py | 2 +- python/tvm/topi/cuda/nms.py | 423 ++++++++++++------- python/tvm/topi/cuda/scan.py | 4 +- python/tvm/topi/cuda/sort.py | 27 +- python/tvm/topi/cuda/vision.py | 2 +- python/tvm/topi/vision/nms.py | 193 ++++++++- python/tvm/topi/vision/nms_util.py | 282 +++++++++++++ src/relay/op/vision/nms.cc | 63 +++ tests/python/frontend/onnx/test_forward.py | 2 +- tests/python/relay/test_op_level5.py | 97 ++++- tests/python/topi/python/test_topi_vision.py | 112 +++++ 18 files changed, 1171 insertions(+), 397 deletions(-) create mode 100644 python/tvm/topi/vision/nms_util.py diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 4a96d391430ee..005b900d5d440 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -86,8 +86,6 @@ struct GetValidCountsAttrs : public tvm::AttrsNode { /*! \brief Attributes used in non_maximum_suppression operator */ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { - Optional max_output_size; - Optional iou_threshold; bool force_suppress; int top_k; int coord_start; @@ -97,8 +95,6 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { + TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs, + "relay.attrs.AllClassNonMaximumSuppressionAttrs") {} +}; + /*! \brief Attributes used in roi_align operators */ struct ROIAlignAttrs : public tvm::AttrsNode { Array pooled_size; diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 14d3c40d40b3c..ffeb0dd731713 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2456,17 +2456,6 @@ class NonMaxSuppression(OnnxOpConverter): @classmethod def _impl_v10(cls, inputs, attr, params): - """ - High level note: ONNX implements what TF calls combined_non_max_suppression - It passes in scores for each box for every class in the output and expects boxes to be - analyzed for each class independently - - It also asks for the data to be returned in a particular format. - - To support these, we implement a series of lops: - The first loop splits over class number, performs NMS, and collects the outputs. 
- The second (nested) loop takes the outputs and transforms them into the format ONNX wants - """ # Get parameter values boxes = inputs[0] scores = inputs[1] @@ -2474,8 +2463,6 @@ def _impl_v10(cls, inputs, attr, params): iou_threshold = inputs[3] score_threshold = inputs[4] - dtype = infer_type(boxes).checked_type.dtype - if "center_point_box" in attr: if attr["center_point_box"] != 0: raise NotImplementedError( @@ -2498,226 +2485,15 @@ def conditionally_squeeze_scalar(x): iou_threshold = conditionally_squeeze_scalar(iou_threshold) score_threshold = conditionally_squeeze_scalar(score_threshold) - ## prepare utility constants - zero = _op.const(np.array([0]), dtype="int64") - one = _op.const(np.array([1]), dtype="int64") - two = _op.const(np.array([2]), dtype="int64") - three = _op.const(np.array([3]), dtype="int64") - three_ones = _op.const(np.array([1, 1, 1]), dtype="int64") - four_ones = _op.const(np.array([1, 1, 1, 1]), dtype="int64") - - ## First loop: split by class and perform NMS - # Create Loop Vars - i = _expr.var("i", shape=(1,), dtype="int64") - scores_var = _expr.var("scores_var", shape=(_ty.Any(), _ty.Any(), _ty.Any()), dtype=dtype) - boxes_var = _expr.var("boxes_var", shape=(_ty.Any(), _ty.Any(), 4), dtype=dtype) - max_output_boxes_per_class_var = _expr.var( - "max_output_boxes_per_class_var", shape=(), dtype="int64" - ) - iou_threshold_var = _expr.var("iou_threshold_var", shape=(), dtype="float32") - score_threshold_var = _expr.var("score_threshold_var", shape=(), dtype="float32") - B = _expr.var("B", shape=(1,), dtype="int64") - C = _expr.var("C", shape=(1,), dtype="int64") - S = _expr.var("S", shape=(1,), dtype="int64") - # Outputs of first loop should be padded nms values shape (B, C, S, 3) - onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64") - # and sizes of valid outputs, shape (B, C, 1) - nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64") - - def _first_cond( - i, - scores, - boxes, - B, - C, - S, - max_output_boxes_per_class, - iou_threshold, - score_threshold, - onnx_out, - nms_size_out, - ): - # Loop over classes, end when i == C - return _op.take(_op.less(i, C), _expr.const(0)) - - def _first_body( - i, - scores, - boxes, - B, - C, - S, - max_output_boxes_per_class, - iou_threshold, - score_threshold, - onnx_out, - nms_size_out, - ): - # slice to get current class - begin = _op.concatenate([zero, i, zero], axis=0) - end = _op.concatenate([B, i + one, S], axis=0) - class_scores = _op.strided_slice(scores, begin, end, three_ones) - class_scores = _op.expand_dims(_op.squeeze(class_scores, [1]), -1, 1) - # combine scores and boxes - data = _op.concatenate([class_scores, boxes], axis=-1) - - # get valid counts - ct, data, indices = _op.vision.get_valid_counts( - data, score_threshold=score_threshold, id_index=-1, score_index=0 - ) - # reason why using get_valid_counts is for inference performance - # ONNX NMS doesn't have parameter top_k - top_k = -1 - # ONNX doesn't have class id for nms input - score_index = 0 - # perform nms on current class - nms_ret = _op.vision.non_max_suppression( - data=data, - valid_count=ct, - indices=indices, - max_output_size=max_output_boxes_per_class, - iou_threshold=iou_threshold, - force_suppress=True, - top_k=top_k, - coord_start=1, - score_index=score_index, - id_index=-1, - return_indices=True, - invalid_to_bottom=False, - ) - # partially prepare ONNX output format by labeling batch_num, class_id - nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1) - 
batch_num = _op.expand_dims(_op.arange(_op.squeeze(B, [0]), dtype="int64"), -1, 1) - batch_num = _op.broadcast_to(batch_num, shape_of(nms_ret[0], dtype="int64")) - batch_num = _op.expand_dims(batch_num, -1, 1) - class_num = _op.broadcast_to(i, shape_of(nms_padded_out, dtype="int64")) - new_onnx_out = _op.concatenate( - [batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1 - ) - new_onnx_out = _op.expand_dims(new_onnx_out, 1, 1) - # store valid nms outputs for this class - nms_size = _op.cast(nms_ret[1], "int64") - nms_size = _op.expand_dims(nms_size, 1, 1) - return [ - i + one, - scores, - boxes, - B, - C, - S, - max_output_boxes_per_class, - iou_threshold, - score_threshold, - _op.concatenate([onnx_out, new_onnx_out], axis=1), - _op.concatenate([nms_size_out, nms_size], axis=1), - ] - - # create the first loop - first_loop = _loops.while_loop( - _first_cond, - [ - i, - scores_var, - boxes_var, - B, - C, - S, - max_output_boxes_per_class_var, - iou_threshold_var, - score_threshold_var, - onnx_out, - nms_size_out, - ], - _first_body, - ) - - ## Second loop slices outputs of the first loop for valid boxes and - ## concats in the order ONNX wants - # Second inner Loop Vars - i = _expr.var("i", shape=(1,), dtype="int64") - j = _expr.var("j", shape=(1,), dtype="int64") - B = _expr.var("B", shape=(1,), dtype="int64") - C = _expr.var("C", shape=(1,), dtype="int64") - # Outputs of first loop should be padded nms values shape (B, C, 3) - onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64") - # and sizes of valid outputs, shape (B, C, 1) - nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64") - out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64") - - def _inner_cond(i, j, C, onnx_out, nms_size, out): - # inner loop over number of classes - return _op.take(_op.less(j, C), _expr.const(0)) - - def _inner_body(i, j, C, onnx_out, nms_size, out): - # slice to get current batch and class for valid box indicator - start = _op.concatenate([i, j + one, zero], axis=0) - end = _op.concatenate([i + one, j + two, one], axis=0) - num_valid_boxes = _op.reshape(_op.strided_slice(nms_size, start, end, three_ones), [1]) - # slice to get current batch, class, and valid outputs - start = _op.concatenate([i, j + one, zero, zero], axis=0) - end = _op.concatenate([i + one, j + two, num_valid_boxes, three], axis=0) - new_out = _op.squeeze(_op.strided_slice(onnx_out, start, end, four_ones), [0, 1]) - return i, j + one, C, onnx_out, nms_size, _op.concatenate([out, new_out], axis=0) - - inner_loop = _loops.while_loop( - _inner_cond, [i, j, C, onnx_out, nms_size_out, out], _inner_body + nms_out = _op.vision.all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold ) - # Second Outer Loop Vars - i = _expr.var("i", shape=(1,), dtype="int64") - j = _expr.var("j", shape=(1,), dtype="int64") - B = _expr.var("B", shape=(1,), dtype="int64") - C = _expr.var("C", shape=(1,), dtype="int64") - # Outputs of first loop should be padded nms values shape (B, C, 3) - onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64") - # and sizes of valid outputs, shape (B, C, 1) - nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64") - out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64") - - def _outer_cond(i, B, C, onnx_out, nms_size_out, out): - # Outer loop is over batch size - return _op.take(_op.less(i, B), _expr.const(0)) - - def 
_outer_body(i, B, C, onnx_out, nms_size_out, out): - # Outer loop just calls inner loop - init_count = _op.const(np.array([0]), dtype="int64") - inner_loop_vals = inner_loop(i, init_count, C, onnx_out, nms_size_out, out) - return i + one, B, C, onnx_out, nms_size_out, _expr.TupleGetItem(inner_loop_vals, 5) - - # Create the second loop - outer_loop = _loops.while_loop( - _outer_cond, [i, B, C, onnx_out, nms_size_out, out], _outer_body - ) - - # Call the first loop, perform NMS - B, C, S = _op.split(shape_of(scores, dtype="int64"), 3) - init_count = _op.const(np.array([0]), dtype="int64") - init_onnx_out = _op.const([1], dtype="int64") - init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, one, S, three], 0)) - init_nms_size_out = _op.const([1], dtype="int64") - init_nms_size_out = _op.broadcast_to(init_nms_size_out, _op.concatenate([B, one, one], 0)) - loop_vals = first_loop( - init_count, - scores, - boxes, - B, - C, - S, - max_output_boxes_per_class, - iou_threshold, - score_threshold, - init_onnx_out, - init_nms_size_out, - ) - onnx_output = _expr.TupleGetItem(loop_vals, 9) - nms_size_output = _expr.TupleGetItem(loop_vals, 10) - - # Call the second loop, rework outputs into correct form - init_count = _op.const(np.array([0]).astype("int64"), dtype="int64") - init_out = _op.const(np.array([1, 1, 1]).reshape([1, 3]).astype("int64"), dtype="int64") - loop_vals = outer_loop(init_count, B, C, onnx_output, nms_size_output, init_out) - loop_out = _expr.TupleGetItem(loop_vals, 5) - return _op.strided_slice(loop_out, [1, 0], shape_of(loop_out), [1, 1]) + three = _op.const(np.array([3]), dtype="int64") + begin = _op.const(np.array([0, 0]), dtype="int64") + end = _op.concatenate([nms_out[1], three], axis=0) + strides = _op.const(np.array([1, 1]), dtype="int64") + return _op.strided_slice(nms_out[0], begin, end, strides) class ATen(OnnxOpConverter): diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 41076817b3749..4cc6e0f26b917 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -304,6 +304,11 @@ class NonMaximumSuppressionAttrs(Attrs): """Attributes for vision.non_maximum_suppression""" +@tvm._ffi.register_object("relay.attrs.AllClassNonMaximumSuppressionAttrs") +class AllClassNonMaximumSuppressionAttrs(Attrs): + """Attributes for vision.all_classnon_maximum_suppression""" + + @tvm._ffi.register_object("relay.attrs.ROIAlignAttrs") class ROIAlignAttrs(Attrs): """Attributes for vision.roi_align""" diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 61ab421427322..e5aa8aa106207 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -955,6 +955,18 @@ def nms_strategy_cuda(attrs, inputs, out_type, target): return strategy +@all_class_nms_strategy.register(["cuda", "gpu"]) +def all_class_nms_strategy_cuda(attrs, inputs, out_type, target): + """all class nms cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_all_class_nms(topi.cuda.all_class_non_max_suppression), + wrap_topi_schedule(topi.cuda.schedule_nms), + name="all_class_nms.cuda", + ) + return strategy + + @roi_align_strategy.register(["cuda", "gpu"]) def roi_align_strategy_cuda(attrs, inputs, out_type, target): """roi_align cuda strategy""" diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 845995e6ace4a..4c25255fd7b38 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ 
b/python/tvm/relay/op/strategy/generic.py @@ -1002,10 +1002,6 @@ def wrap_compute_nms(topi_compute): def _compute_nms(attrs, inputs, out_type): max_output_size = inputs[3] iou_threshold = inputs[4] - if attrs.max_output_size is not None: - max_output_size = attrs.max_output_size - if attrs.iou_threshold is not None: - iou_threshold = get_const_float(attrs.iou_threshold) return_indices = bool(get_const_int(attrs.return_indices)) force_suppress = bool(get_const_int(attrs.force_suppress)) top_k = get_const_int(attrs.top_k) @@ -1060,6 +1056,30 @@ def nms_strategy(attrs, inputs, out_type, target): return strategy +def wrap_compute_all_class_nms(topi_compute): + """wrap all class nms topi compute""" + + def _compute_nms(attrs, inputs, out_type): + max_output_size = inputs[2] + iou_threshold = inputs[3] + score_threshold = inputs[4] + return topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold, score_threshold) + + return _compute_nms + + +@override_native_generic_func("all_class_non_max_suppression_strategy") +def all_class_nms_strategy(attrs, inputs, out_type, target): + """all class nms generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_all_class_nms(topi.vision.all_class_non_max_suppression), + wrap_topi_schedule(topi.generic.schedule_nms), + name="all_class_nms.generic", + ) + return strategy + + # roi_align def wrap_compute_roi_align(topi_compute): """wrap roi_align topi compute""" diff --git a/python/tvm/relay/op/vision/_vision.py b/python/tvm/relay/op/vision/_vision.py index 9c8c853fa3d20..7a31bce5ad49e 100644 --- a/python/tvm/relay/op/vision/_vision.py +++ b/python/tvm/relay/op/vision/_vision.py @@ -45,6 +45,9 @@ reg.register_strategy("vision.non_max_suppression", strategy.nms_strategy) reg.register_pattern("vision.non_max_suppression", OpPattern.OPAQUE) +reg.register_strategy("vision.all_class_non_max_suppression", strategy.all_class_nms_strategy) +reg.register_pattern("vision.all_class_non_max_suppression", OpPattern.OPAQUE) + @script def _get_valid_counts_shape_func(data_shape): @@ -85,6 +88,22 @@ def nms_shape_func(attrs, inputs, _): return [topi.math.identity(inputs[0])] +@script +def _all_class_nms_shape_func(boxes_shape, scores_shape): + out_shape = output_tensor((2,), "int64") + count_shape = output_tensor((1,), "int64") + + out_shape[0] = boxes_shape[0] * scores_shape[1] * boxes_shape[1] + out_shape[1] = 3 + count_shape[0] = int64(1) + return out_shape, count_shape + + +@reg.register_shape_func("vision.all_class_non_max_suppression", False) +def all_class_nms_shape_func(attrs, inputs, _): + return _all_class_nms_shape_func(inputs[0], inputs[1]) + + @script def _roi_align_shape_func_nchw(data_shape, rois_shape, pooled_size): out = output_tensor((4,), "int64") diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index 0a3df40b99dfd..3f829e0b1cc7d 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -149,3 +149,51 @@ def non_max_suppression( if return_indices: return expr.TupleWrapper(out, 2) return out + + +def all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0 +): + """Non-maximum suppression operator for object detection, corresponding to ONNX + NonMaxSuppression and TensorFlow combined_non_max_suppression. + NMS is performed for each class separately. 
+
+    Parameters
+    ----------
+    boxes : relay.Expr
+        3-D tensor with shape (batch_size, num_boxes, 4)
+
+    scores : relay.Expr
+        3-D tensor with shape (batch_size, num_classes, num_boxes)
+
+    max_output_boxes_per_class : int or relay.Expr, optional
+        The maximum number of output selected boxes per class
+
+    iou_threshold : float or relay.Expr, optional
+        IoU test threshold
+
+    score_threshold : float or relay.Expr, optional
+        Score threshold to filter out low score boxes early
+
+    Returns
+    -------
+    out : relay.Tuple
+        The output is a relay.Tuple of two tensors, the first is `indices` of size
+        `(batch_size * num_class * num_boxes, 3)` and the second is a scalar tensor
+        `num_total_detection` of shape `(1,)` representing the total number of selected boxes.
+        Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come first,
+        in descending order of scores, followed by boxes from batch 0, class 1 etc. Out of
+        `batch_size * num_class * num_boxes` rows of indices, only the first
+        `num_total_detection` rows are valid.
+    """
+    if not isinstance(max_output_boxes_per_class, expr.Expr):
+        max_output_boxes_per_class = expr.const(max_output_boxes_per_class, "int32")
+    if not isinstance(iou_threshold, expr.Expr):
+        iou_threshold = expr.const(iou_threshold, "float32")
+    if not isinstance(score_threshold, expr.Expr):
+        score_threshold = expr.const(score_threshold, "float32")
+
+    out = _make.all_class_non_max_suppression(
+        boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold
+    )
+    return expr.TupleWrapper(out, 2)
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index c2f55668d2e23..4d838db8bfba7 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -43,7 +43,7 @@
 from .batch_matmul_tensorcore import *
 from .vision import *
 from .ssd import *
-from .nms import get_valid_counts, non_max_suppression
+from .nms import get_valid_counts, non_max_suppression, all_class_non_max_suppression
 from .rcnn import *
 from .scatter import *
 from .sort import *
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index c83dae0d3b96d..2789452cc10be 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -25,6 +25,14 @@
 from .sort import argsort, argsort_thrust
 from .scan import exclusive_scan
 from ..utils import ceil_div
+from ..math import cast
+from ..transform import reshape
+from ..vision.nms_util import (
+    calculate_overlap,
+    binary_search,
+    collect_selected_indices,
+    run_all_class_nms,
+)


 def cuda_atomic_add_rule(op):
@@ -265,6 +273,97 @@
     return [valid_count, out, out_indices]


+def _nms_loop(
+    ib,
+    batch_size,
+    top_k,
+    iou_threshold,
+    max_output_size,
+    valid_count,
+    on_new_valid_box_func,
+    on_new_invalidated_box_func,
+    needs_bbox_check_func,
+    calc_overlap_func,
+    out_scores,
+    num_valid_boxes,
+):
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+
+    with ib.new_scope():
+        nthread_by = batch_size
+        nthread_tx = max_threads
+
+        # Some cuda architectures have a smaller limit of 32K for cudaDevAttrMaxRegistersPerBlock
+        # vs 64K for most GPUs. Since this kernel uses many registers (around 35), the limit will
+        # be exceeded with 1024 threads.
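+        # (Capping at 512 threads keeps roughly 35 registers/thread within the 32K
+        # budget: 512 * 35 = 17,920, whereas 1024 * 35 = 35,840 would exceed it.)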
+ target = tvm.target.Target.current(allow_none=False) + if target.kind.name == "cuda": + if nvcc.get_target_compute_version(target) in ["3.2", "5.3", "6.2"]: + nthread_tx = 512 + + by = te.thread_axis("blockIdx.y") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(tx, "thread_extent", nthread_tx) + + num_valid_boxes_local = ib.allocate( + "int32", (1,), name="num_valid_boxes_local", scope="local" + ) + num_valid_boxes_local[0] = 0 + + def nms_inner_loop(ib, i, j, nkeep): + # The box j is valid, invalidate other boxes that overlap with j above iou_threshold + on_new_valid_box_func(ib, tx, num_valid_boxes_local[0], i, j) + num_valid_boxes_local[0] += 1 + + num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx) + + with ib.for_range(0, num_iter_per_thread, name="_k") as _k: + k = j + 1 + _k * nthread_tx + tx + + with ib.if_scope( + tvm.tir.all( + k < nkeep, + out_scores[i, k] > 0, # is the box k still valid? + needs_bbox_check_func(i, j, k), + ) + ): + iou = calc_overlap_func(i, j, k) + + with ib.if_scope(iou >= iou_threshold): + # invalidate the box k + out_scores[i, k] = -1.0 + on_new_invalidated_box_func(i, k) + + ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) + + i = by + + nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) + max_output_size = if_then_else(max_output_size > 0, max_output_size, nkeep) + + with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): + # Apply nms + # No need to do more iteration if we have already reached max_output_size boxes + box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local") + box_idx[0] = 0 + with ib.while_loop( + tvm.tir.all(box_idx[0] < nkeep, num_valid_boxes_local[0] < max_output_size) + ): + # Proceed to the inner loop if the box with id box_idx is still valid + with ib.if_scope(out_scores[i, box_idx[0]] > -1.0): + nms_inner_loop(ib, i, box_idx[0], nkeep) + box_idx[0] += 1 + + with ib.if_scope(tx + 0 == 0): + num_valid_boxes[i] = num_valid_boxes_local[0] + + with ib.else_scope(): + num_valid_boxes[i] = 0 + + return ib.get() + + def nms_ir( data, sorted_index, @@ -352,43 +451,6 @@ def nms_ir( stmt : Stmt The result IR statement. 
""" - - def get_boundaries(output, box_idx): - l = tvm.te.min( - output[box_idx], - output[box_idx + 2], - ) - t = tvm.te.min( - output[box_idx + 1], - output[box_idx + 3], - ) - r = tvm.te.max( - output[box_idx], - output[box_idx + 2], - ) - b = tvm.te.max( - output[box_idx + 1], - output[box_idx + 3], - ) - return l, t, r, b - - def calculate_overlap(out_tensor, box_a_idx, box_b_idx): - """Calculate overlap of two boxes.""" - a_l, a_t, a_r, a_b = get_boundaries(out_tensor, box_a_idx) - b_l, b_t, b_r, b_b = get_boundaries(out_tensor, box_b_idx) - - # Overlapping width and height - w = tvm.te.max(0.0, tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l)) - h = tvm.te.max(0.0, tvm.te.min(a_b, b_b) - tvm.te.max(a_t, b_t)) - - # Overlapping area - area = h * w - - # total area of the figure formed by box a and box b - # except for overlapping area - u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area - return tvm.tir.Select(u <= 0.0, 0.0, area / u) - batch_size = data.shape[0] num_anchors = data.shape[1] box_data_length = data.shape[2] @@ -492,105 +554,51 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): box_indices[i * num_anchors + j] = j - with ib.new_scope(): - nthread_by = batch_size - nthread_tx = max_threads - - # Some cuda architectures have smaller limit of 32K for cudaDevAttrMaxRegistersPerBlock - # vs 64K for most GPUs. Since this kernel uses many registers (around 35), the limit will - # be exceeded with 1024 threads. - target = tvm.target.Target.current(allow_none=False) - if target.kind.name == "cuda": - if nvcc.get_target_compute_version(target) in ["3.2", "5.3", "6.2"]: - nthread_tx = 512 - - by = te.thread_axis("blockIdx.y") - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(by, "thread_extent", nthread_by) - ib.scope_attr(tx, "thread_extent", nthread_tx) - - i = by + if isinstance(max_output_size, int): + max_output_size = tvm.tir.const(max_output_size) + def calc_overlap(i, j, k): + offset_j = j * 4 + offset_k = k * 4 base_bbox_idx = i * num_anchors * 4 - num_valid_boxes_local = ib.allocate( - "int32", (1,), name="num_valid_boxes_local", scope="local" + return calculate_overlap( + out_bboxes, + base_bbox_idx + offset_j, + base_bbox_idx + offset_k, ) - num_valid_boxes_local[0] = 0 - nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) - - def nms_inner_loop(ib, j): - # The box j is valid, invalidate other boxes that overlap with j above iou_threshold - - # When return_indices is False, no need to populate box_indices - if return_indices: - with ib.if_scope(tx + 0 == 0): - orig_idx = sorted_index[i * num_anchors + j] - box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] - - num_valid_boxes_local[0] += 1 - offset_j = j * 4 - num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx) - - with ib.for_range(0, num_iter_per_thread, name="_k") as _k: - k = j + 1 + _k * nthread_tx + tx - offset_k = k * 4 - - with ib.if_scope( - tvm.tir.all( - k < nkeep, - out_scores[i, k] > 0, # is the box k still valid? 
- tvm.tir.any( - force_suppress > 0, - id_index < 0, - out_class_ids[i, k] == out_class_ids[i, j], - ), - ) - ): - iou = calculate_overlap( - out_bboxes, - base_bbox_idx + offset_j, - base_bbox_idx + offset_k, - ) - with ib.if_scope(iou >= iou_threshold): - # invalidate the box k - out_scores[i, k] = -1.0 - - if return_indices is False and id_index >= 0: - out_class_ids[i, k] = -1.0 - - ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) - - if isinstance(max_output_size, int): - max_output_size = tvm.tir.const(max_output_size) - - with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): - # Apply nms - with ib.if_scope(max_output_size > 0): - # No need to do more iteration if we have already reached max_output_size boxes - box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local") - box_idx[0] = 0 - with ib.while_loop( - tvm.tir.all(box_idx[0] < nkeep, num_valid_boxes_local[0] < max_output_size) - ): - # Proceed to the inner loop if the box with id box_idx is still valid - with ib.if_scope(out_scores[i, box_idx[0]] > -1.0): - nms_inner_loop(ib, box_idx[0]) - box_idx[0] += 1 - - with ib.else_scope(): - with ib.for_range(0, nkeep, name="j") as j: - # Proceed to the inner loop if the box j is still valid - with ib.if_scope(out_scores[i, j] > -1.0): - nms_inner_loop(ib, j) - - with ib.if_scope(tx + 0 == 0): - num_valid_boxes[i] = num_valid_boxes_local[0] - - with ib.else_scope(): - num_valid_boxes[i] = 0 + def on_new_valid_box(ib, tid, num_current_valid_box, i, j): + # When return_indices is False, no need to populate box_indices + if return_indices: + with ib.if_scope(tid + 0 == 0): + orig_idx = sorted_index[i * num_anchors + j] + box_indices[i, num_current_valid_box] = indices[i, orig_idx] + + def on_new_invalidated_box(i, k): + if return_indices is False and id_index >= 0: + out_class_ids[i, k] = -1.0 + + def needs_bbox_check(i, j, k): + return tvm.tir.any( + force_suppress > 0, + id_index < 0, + out_class_ids[i, k] == out_class_ids[i, j], + ) - return ib.get() + return _nms_loop( + ib, + batch_size, + top_k, + iou_threshold, + max_output_size, + valid_count, + on_new_valid_box, + on_new_invalidated_box, + needs_bbox_check, + calc_overlap, + out_scores, + num_valid_boxes, + ) def _fetch_score_ir(data, score, axis): @@ -622,6 +630,16 @@ def _fetch_score_ir(data, score, axis): return ib.get() +def _dispatch_sort(scores, ret_type="indices"): + target = tvm.target.Target.current() + if target and ( + can_use_thrust(target, "tvm.contrib.thrust.sort") + or can_use_rocthrust(target, "tvm.contrib.thrust.sort") + ): + return argsort_thrust(scores, axis=1, is_ascend=False, dtype="int32", ret_type=ret_type) + return argsort(scores, axis=1, is_ascend=False, dtype="int32", ret_type=ret_type) + + def _get_sorted_indices(data, data_buf, score_index, score_shape): """Extract a 1D score tensor from the packed input and do argsort on it.""" score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8) @@ -639,17 +657,7 @@ def _get_sorted_indices(data, data_buf, score_index, score_shape): name="fetch_score", tag="fetch_score", ) - - target = tvm.target.Target.current() - if target and ( - can_use_thrust(target, "tvm.contrib.thrust.sort") - or can_use_rocthrust(target, "tvm.contrib.thrust.sort") - ): - sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype="int32") - else: - sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32") - - return sort_tensor + return _dispatch_sort(score_tensor) 
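For reference, a minimal sketch of how the refactored sort dispatch above is consumed (`scores_2d` is an illustrative name for the (batch * num_class, num_boxes) reshape used by the new all-class op):

    # existing single-class path: indices only, as before
    sort_tensor = _dispatch_sort(score_tensor)
    # new all-class NMS path: sorted scores and their indices from one sort
    sorted_scores, sorted_indices = _dispatch_sort(scores_2d, ret_type="both")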
def _run_nms( @@ -910,3 +918,134 @@ def non_max_suppression( score_index, id_index, ) + + +def _get_valid_box_count(scores, score_threshold): + batch_classes, num_boxes = scores.shape + + def searchsorted_ir(scores, valid_count): + ib = tvm.tir.ir_builder.create() + scores = ib.buffer_ptr(scores) + valid_count = ib.buffer_ptr(valid_count) + + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + + with ib.new_scope(): + ib.scope_attr(bx, "thread_extent", ceil_div(batch_classes, max_threads)) + ib.scope_attr(tx, "thread_extent", max_threads) + tid = bx * max_threads + tx + + with ib.if_scope(tid < batch_classes): + binary_search(ib, tid, num_boxes, scores, score_threshold, valid_count) + + return ib.get() + + scores_buf = tvm.tir.decl_buffer(scores.shape, scores.dtype, "scores_buf", data_alignment=8) + + return te.extern( + [(batch_classes,)], + [scores], + lambda ins, outs: searchsorted_ir(ins[0], outs[0]), + dtype=["int32"], + in_buffers=[scores_buf], + name="searchsorted", + tag="searchsorted", + ) + + +def _collect_selected_indices_ir(num_class, selected_indices, num_detections, row_offsets, out): + batch_classes, num_boxes = selected_indices.shape + + ib = tvm.tir.ir_builder.create() + + selected_indices = ib.buffer_ptr(selected_indices) + num_detections = ib.buffer_ptr(num_detections) + row_offsets = ib.buffer_ptr(row_offsets) + out = ib.buffer_ptr(out) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = ceil_div(num_boxes, nthread_tx) + nthread_by = batch_classes + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + ib.scope_attr(by, "thread_extent", nthread_by) + + with ib.new_scope(): + idx = bx * nthread_tx + tx + idy = cast(by, "int64") + batch_id = idy // num_class + class_id = idy % num_class + with ib.if_scope(idx < num_detections[idy]): + out[row_offsets[idy] + idx, 0] = batch_id + out[row_offsets[idy] + idx, 1] = class_id + out[row_offsets[idy] + idx, 2] = cast(selected_indices[idy, idx], "int64") + + return ib.get() + + +def all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold +): + """Non-maximum suppression operator for object detection, corresponding to ONNX + NonMaxSuppression and TensorFlow combined_non_max_suppression. + NMS is performed for each class separately. + + Parameters + ---------- + boxes : tvm.te.Tensor + 3-D tensor with shape (batch_size, num_boxes, 4) + + scores: tvm.te.Tensor + 3-D tensor with shape (batch_size, num_classes, num_boxes) + + max_output_boxes_per_class : int or tvm.te.Tensor, optional + The maxinum number of output selected boxes per class + + iou_threshold : float or tvm.te.Tensor, optionaIl + IoU test threshold + + score_threshold : float or tvm.te.Tensor, optional + Score threshold to filter out low score boxes early + + Returns + ------- + out : [tvm.te.Tensor, tvm.te.Tensor] + The output is two tensors, the first is `indices` of size + `(batch_size * num_class* num_boxes , 3)` and the second is a scalar tensor + `num_total_detection` of shape `(1,)` representing the total number of selected + boxes. Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come + first, in descending of scores, followed by boxes from batch 0, class 1 etc. 
Out of + `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection` + rows are valid. + """ + batch, num_class, num_boxes = scores.shape + + scores = reshape(scores, (batch * num_class, num_boxes)) + sorted_scores, sorted_indices = _dispatch_sort(scores, ret_type="both") + valid_count = _get_valid_box_count(sorted_scores, score_threshold) + + selected_indices, num_detections = run_all_class_nms( + boxes, + sorted_scores, + sorted_indices, + valid_count, + max_output_boxes_per_class, + iou_threshold, + _nms_loop, + ) + + row_offsets, num_total_detections = exclusive_scan( + num_detections, return_reduction=True, output_dtype="int64" + ) + + selected_indices = collect_selected_indices( + num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + ) + + return [selected_indices, num_total_detections] diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 3240ebcd515c8..5d3798e3d27bb 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -81,7 +81,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i ib.scope_attr(bx, "thread_extent", batch_size) with ib.if_scope(bx < batch_size): if reduction is not None: - reduction[bx] = 0 + reduction[bx] = cast(identity_value, out_dtype) with ib.else_scope(): with ib.new_scope(): nthread_tx = max_threads @@ -393,7 +393,7 @@ def do_scan(data, output_dtype): lambda ins, outs: exclusive_scan_ir( ins[0], outs[0], outs[1], binop=binop, identity_value=identity_value ), - dtype=[data.dtype, output_dtype], + dtype=[output_dtype, output_dtype], in_buffers=[data_buf], name="exclusive_scan", tag="exclusive_scan_gpu", diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 5ebd3060a6bbf..93e4d3feccc79 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -739,7 +739,7 @@ def sort_thrust(data, axis=-1, is_ascend=1): return out -def argsort(data, axis=-1, is_ascend=1, dtype="float32"): +def argsort(data, axis=-1, is_ascend=1, dtype="float32", ret_type="indices"): """Performs sorting along the given axis and returns an array of indicies having same shape as an input array that index data in sorted order. @@ -757,6 +757,11 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"): dtype : string, optional DType of the output indices. + ret_type : string, optional + The return type [both, indices]. + "both": return both sorted data and indices. + "indices": return sorted indices only. 
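+        "both" is used by the new all_class_non_max_suppression path to obtain the
+        sorted scores and their indices from a single sort.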
+ Returns ------- out : tvm.te.Tensor @@ -774,7 +779,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"): indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8) - out = te.extern( + outs = te.extern( [data.shape, data.shape, data.shape, data.shape], [data], lambda ins, outs: sort_ir( @@ -789,16 +794,19 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"): out_buffers=[value_buf, indices_buf, value_swap_buf, indices_swap_buf], name="argsort_gpu", tag="argsort_gpu", - )[1] + ) if axis != ndim - 1: axes = swap(list(range(ndim)), axis) - out = transpose(out, axes) + outs = [transpose(out, axes) for out in outs] - return out + if ret_type == "indices": + return outs[1] + + return outs[0], outs[1] -def argsort_thrust(data, axis=-1, is_ascend=1, dtype="float32"): +def argsort_thrust(data, axis=-1, is_ascend=1, dtype="float32", ret_type="indices"): """Performs sorting along the given axis and returns an array of indicies having same shape as an input array that index data in sorted order. @@ -816,12 +824,17 @@ def argsort_thrust(data, axis=-1, is_ascend=1, dtype="float32"): dtype : string, optional DType of the output indices. + ret_type : string, optional + The return type [both, indices]. + "both": return both sorted data and indices. + "indices": return sorted indices only. + Returns ------- out : tvm.te.Tensor The output of this function. """ - return topk_thrust(data, 0, axis, "indices", is_ascend, dtype) + return topk_thrust(data, 0, axis, ret_type, is_ascend, dtype) def schedule_sort(outs): diff --git a/python/tvm/topi/cuda/vision.py b/python/tvm/topi/cuda/vision.py index 73b24deb35aea..88983ab89f76a 100644 --- a/python/tvm/topi/cuda/vision.py +++ b/python/tvm/topi/cuda/vision.py @@ -32,7 +32,7 @@ def _default_schedule(outs): scheduled_ops = [] def traverse(op): - if tag.is_broadcast(op.tag) or op.tag in ["bbox_score", "sorted_bbox"]: + if tag.is_injective(op.tag) or op.tag in ["bbox_score", "sorted_bbox"]: schedule_injective_from_existing(s, op.output(0)) for tensor in op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops: diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 8be62a73c09ef..744c5ef7feda1 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -20,7 +20,18 @@ from tvm import te from tvm.te import hybrid -from ..sort import argsort +from tvm.tir import if_then_else + +from ..sort import sort, argsort +from ..math import cast +from ..transform import reshape +from .. 
import reduction +from ..scan import cumsum +from .nms_util import ( + binary_search, + collect_selected_indices, + run_all_class_nms, +) @hybrid.script @@ -597,3 +608,183 @@ def non_max_suppression( num_anchors=num_anchors, ) return out + + +def _nms_loop( + ib, + batch_size, + top_k, + iou_threshold, + max_output_size, + valid_count, + on_new_valid_box_func, + on_new_invalidated_box_func, + needs_bbox_check_func, + calc_overlap_func, + out_scores, + num_valid_boxes, +): + def nms_inner_loop(ib, i, j, nkeep, num_valid_boxes_local): + # The box j is valid, invalidate other boxes that overlap with j above iou_threshold + on_new_valid_box_func(ib, 0, num_valid_boxes_local[0], i, j) + num_valid_boxes_local[0] += 1 + + num_boxes_to_check = nkeep - (j + 1) + + with ib.for_range(0, num_boxes_to_check, name="_k", kind="parallel") as _k: + k = j + 1 + _k + + with ib.if_scope( + tvm.tir.all( + k < nkeep, + out_scores[i, k] > 0, # is the box k still valid? + needs_bbox_check_func(i, j, k), + ) + ): + iou = calc_overlap_func(i, j, k) + + with ib.if_scope(iou >= iou_threshold): + # invalidate the box k + out_scores[i, k] = -1.0 + on_new_invalidated_box_func(i, k) + + with ib.for_range(0, batch_size, name="i") as i: + nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) + max_output_size = if_then_else(max_output_size > 0, max_output_size, nkeep) + + with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): + num_valid_boxes_local = ib.allocate( + "int32", (1,), name="num_valid_boxes_local", scope="local" + ) + box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local") + num_valid_boxes_local[0] = 0 + box_idx[0] = 0 + + # Apply nms + # No need to do more iteration if we have already reached max_output_size boxes + with ib.while_loop( + tvm.tir.all(box_idx[0] < nkeep, num_valid_boxes_local[0] < max_output_size) + ): + # Proceed to the inner loop if the box with id box_idx is still valid + with ib.if_scope(out_scores[i, box_idx[0]] > -1.0): + nms_inner_loop(ib, i, box_idx[0], nkeep, num_valid_boxes_local) + box_idx[0] += 1 + + num_valid_boxes[i] = num_valid_boxes_local[0] + + with ib.else_scope(): + num_valid_boxes[i] = 0 + + return ib.get() + + +def _get_valid_box_count(scores, score_threshold): + batch_classes, num_boxes = scores.shape + + def searchsorted_ir(scores, valid_count): + ib = tvm.tir.ir_builder.create() + scores = ib.buffer_ptr(scores) + valid_count = ib.buffer_ptr(valid_count) + + with ib.for_range(0, batch_classes, name="i", kind="parallel") as i: + binary_search(ib, i, num_boxes, scores, score_threshold, valid_count) + + return ib.get() + + scores_buf = tvm.tir.decl_buffer(scores.shape, scores.dtype, "scores_buf", data_alignment=8) + + return te.extern( + [(batch_classes,)], + [scores], + lambda ins, outs: searchsorted_ir(ins[0], outs[0]), + dtype=["int32"], + in_buffers=[scores_buf], + name="searchsorted", + tag="searchsorted", + ) + + +def _collect_selected_indices_ir(num_class, selected_indices, num_detections, row_offsets, out): + batch_classes, _ = selected_indices.shape + + ib = tvm.tir.ir_builder.create() + + selected_indices = ib.buffer_ptr(selected_indices) + num_detections = ib.buffer_ptr(num_detections) + row_offsets = ib.buffer_ptr(row_offsets) + out = ib.buffer_ptr(out) + + with ib.for_range(0, batch_classes, name="i", kind="parallel") as i: + i = cast(i, "int64") + batch_id = i // num_class + class_id = i % num_class + + with ib.for_range(0, num_detections[i], name="j") as j: + out[row_offsets[i] + j, 0] = batch_id 
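+            # columns are (batch_id, class_id, box_index); row_offsets[i] gives the
+            # starting row for this (batch, class) pair in the flat output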
+            out[row_offsets[i] + j, 1] = class_id
+            out[row_offsets[i] + j, 2] = cast(selected_indices[i, j], "int64")
+
+    return ib.get()
+
+
+def all_class_non_max_suppression(
+    boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold
+):
+    """Non-maximum suppression operator for object detection, corresponding to ONNX
+    NonMaxSuppression and TensorFlow combined_non_max_suppression.
+    NMS is performed for each class separately.
+
+    Parameters
+    ----------
+    boxes : tvm.te.Tensor
+        3-D tensor with shape (batch_size, num_boxes, 4)
+
+    scores : tvm.te.Tensor
+        3-D tensor with shape (batch_size, num_classes, num_boxes)
+
+    max_output_boxes_per_class : int or tvm.te.Tensor, optional
+        The maximum number of output selected boxes per class
+
+    iou_threshold : float or tvm.te.Tensor, optional
+        IoU test threshold
+
+    score_threshold : float or tvm.te.Tensor, optional
+        Score threshold to filter out low score boxes early
+
+    Returns
+    -------
+    out : [tvm.te.Tensor, tvm.te.Tensor]
+        The output is two tensors, the first is `indices` of size
+        `(batch_size * num_class * num_boxes, 3)` and the second is a scalar tensor
+        `num_total_detection` of shape `(1,)` representing the total number of selected
+        boxes. Rows of `indices` are ordered such that selected boxes from batch 0, class 0
+        come first, in descending order of scores, followed by boxes from batch 0, class 1
+        etc. Out of `batch_size * num_class * num_boxes` rows of indices, only the first
+        `num_total_detection` rows are valid.
+    """
+    batch, num_class, num_boxes = scores.shape
+    scores = reshape(scores, (batch * num_class, num_boxes))
+
+    sorted_scores = sort(scores, axis=1, is_ascend=False)
+    sorted_indices = argsort(scores, axis=1, is_ascend=False, dtype="int32")
+    valid_count = _get_valid_box_count(sorted_scores, score_threshold)
+
+    selected_indices, num_detections = run_all_class_nms(
+        boxes,
+        sorted_scores,
+        sorted_indices,
+        valid_count,
+        max_output_boxes_per_class,
+        iou_threshold,
+        _nms_loop,
+    )
+
+    row_offsets = cumsum(num_detections, exclusive=True, dtype="int64")
+
+    num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1)
+
+    selected_indices = collect_selected_indices(
+        num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir
+    )
+
+    return [selected_indices, num_total_detections]
diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py
new file mode 100644
index 0000000000000..1147b1687783d
--- /dev/null
+++ b/python/tvm/topi/vision/nms_util.py
@@ -0,0 +1,282 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Common utilities used in Non-maximum suppression operators"""
+import tvm
+from tvm import te
+
+
+def _get_boundaries(output, box_idx):
+    l = tvm.te.min(
+        output[box_idx],
+        output[box_idx + 2],
+    )
+    t = tvm.te.min(
+        output[box_idx + 1],
+        output[box_idx + 3],
+    )
+    r = tvm.te.max(
+        output[box_idx],
+        output[box_idx + 2],
+    )
+    b = tvm.te.max(
+        output[box_idx + 1],
+        output[box_idx + 3],
+    )
+    return l, t, r, b
+
+
+def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
+    """Calculate overlap of two boxes."""
+    a_l, a_t, a_r, a_b = _get_boundaries(out_tensor, box_a_idx)
+    b_l, b_t, b_r, b_b = _get_boundaries(out_tensor, box_b_idx)
+
+    # Overlapping width and height
+    w = tvm.te.max(0.0, tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l))
+    h = tvm.te.max(0.0, tvm.te.min(a_b, b_b) - tvm.te.max(a_t, b_t))
+
+    # Overlapping area
+    area = h * w
+
+    # total area of the figure formed by box a and box b
+    # except for overlapping area
+    u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area
+    return tvm.tir.Select(u <= 0.0, 0.0, area / u)
+
+
+def binary_search(ib, y, num_boxes, scores, score_threshold, out):
+    """Binary search for score_threshold on scores sorted in descending order"""
+    lo = ib.allocate("int32", (1,), name="lo", scope="local")
+    hi = ib.allocate("int32", (1,), name="hi", scope="local")
+
+    lo[0] = 0
+    hi[0] = num_boxes
+
+    with ib.while_loop(lo[0] < hi[0]):
+        mid = (hi[0] + lo[0]) >> 1
+        with ib.if_scope(scores[y, mid] > score_threshold):
+            lo[0] = mid + 1
+        with ib.else_scope():
+            hi[0] = mid
+
+    out[y] = lo[0]
+
+
+def collect_selected_indices(num_class, selected_indices, num_detections, row_offsets, ir):
+    """Collect selected indices from the core NMS loop into one linear output
+
+    Parameters
+    ----------
+    num_class : int
+
+    selected_indices : tvm.te.Tensor
+        2-D tensor with shape (batch_size * num_classes, num_boxes), representing the indices
+        of boxes selected by the core NMS loop.
+
+    num_detections : tvm.te.Tensor
+        1-D tensor with shape (batch_size * num_classes,), representing
+        the number of boxes selected by the core NMS loop, per batch and class
+
+    row_offsets : tvm.te.Tensor
+        1-D tensor with shape (batch_size * num_classes,); this should be the exclusive scan
+        of num_detections
+
+    ir : function
+        A function to generate IR for CPU or GPU, see its usage in vision/nms.py and cuda/nms.py
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        The output is indices of size (batch_size * num_class * num_boxes, 3).
+        Rows of indices are ordered such that selected boxes from batch 0, class 0 come
+        first, in descending order of scores, followed by boxes from batch 0, class 1 etc.
+    """
+    batch_class, num_boxes = selected_indices.shape
+
+    selected_indices_buf = tvm.tir.decl_buffer(
+        selected_indices.shape, selected_indices.dtype, "selected_indices_buf", data_alignment=8
+    )
+    num_detections_buf = tvm.tir.decl_buffer(
+        num_detections.shape, num_detections.dtype, "num_detections_buf", data_alignment=8
+    )
+    row_offsets_buf = tvm.tir.decl_buffer(
+        row_offsets.shape, row_offsets.dtype, "row_offsets_buf", data_alignment=8
+    )
+
+    return te.extern(
+        [(batch_class * num_boxes, 3)],
+        [selected_indices, num_detections, row_offsets],
+        lambda ins, outs: ir(num_class, ins[0], ins[1], ins[2], outs[0]),
+        dtype=["int64"],
+        in_buffers=[selected_indices_buf, num_detections_buf, row_offsets_buf],
+        name="collect_indices",
+        tag="collect_indices",
+    )
+
+
+def _all_class_nms_ir(
+    boxes,
+    sorted_scores,
+    sorted_indices,
+    valid_count,
+    batch_class,
+    num_class,
+    num_anchors,
+    iou_threshold,
+    max_output_size_per_class,
+    box_indices,
+    num_valid_boxes,
+    nms_loop,
+):
+    ib = tvm.tir.ir_builder.create()
+    boxes = ib.buffer_ptr(boxes)
+    sorted_scores = ib.buffer_ptr(sorted_scores)
+    sorted_indices = ib.buffer_ptr(sorted_indices)
+    valid_count = ib.buffer_ptr(valid_count)
+    box_indices = ib.buffer_ptr(box_indices)
+    num_valid_boxes = ib.buffer_ptr(num_valid_boxes)
+
+    if isinstance(iou_threshold, float):
+        iou_threshold = tvm.tir.FloatImm("float32", iou_threshold)
+
+    if isinstance(max_output_size_per_class, int):
+        max_output_size_per_class = tvm.tir.const(max_output_size_per_class)
+
+    def calc_overlap(i, j, k):
+        offset_j = sorted_indices[i, j] * 4
+        offset_k = sorted_indices[i, k] * 4
+        batch_id = i // num_class
+        base_bbox_idx = batch_id * num_anchors * 4
+        return calculate_overlap(
+            boxes,
+            base_bbox_idx + offset_j,
+            base_bbox_idx + offset_k,
+        )
+
+    def on_new_valid_box(ib, tid, num_current_valid_box, i, j):
+        with ib.if_scope(tid + 0 == 0):
+            box_indices[i, num_current_valid_box] = sorted_indices[i, j]
+
+    def on_new_invalidated_box(*_):
+        pass
+
+    def needs_bbox_check(*_):
+        return tvm.tir.const(True)
+
+    return nms_loop(
+        ib,
+        batch_class,
+        tvm.tir.IntImm("int32", -1),  # top_k
+        iou_threshold,
+        max_output_size_per_class,
+        valid_count,
+        on_new_valid_box,
+        on_new_invalidated_box,
+        needs_bbox_check,
+        calc_overlap,
+        sorted_scores,
+        num_valid_boxes,
+    )
+
+
+def run_all_class_nms(
+    boxes,
+    sorted_scores,
+    sorted_indices,
+    valid_count,
+    max_output_size_per_class,
+    iou_threshold,
+    nms_loop,
+):
+    """The core all class NMS routine
+
+    Parameters
+    ----------
+    boxes : tvm.te.Tensor
+        3-D tensor with shape (batch_size, num_boxes, 4)
+
+    sorted_scores : tvm.te.Tensor
+        2-D tensor with shape (batch_size * num_classes, num_boxes)
+        One of the outputs from argsort
+
+    sorted_indices : tvm.te.Tensor
+        2-D tensor with shape (batch_size * num_classes, num_boxes)
+        The other output from argsort
+
+    valid_count : tvm.te.Tensor
+        1-D tensor with shape (batch_size * num_classes,), representing
+        the number of boxes whose score is above score_threshold, per batch and class
+
+    max_output_size_per_class : int or tvm.te.Tensor, optional
+        The maximum number of output selected boxes per class
+
+    iou_threshold : float or tvm.te.Tensor, optional
+        IoU test threshold
+
+    nms_loop : function
+        A core NMS loop, see its usage in vision/nms.py and cuda/nms.py
+
+    Returns
+    -------
+    out : [tvm.te.Tensor, tvm.te.Tensor]
+        The output is two tensors, the first is indices of size
+        (batch_size * num_class, num_boxes) and the second is a tensor
+        num_selected_boxes of shape
(batch_size * num_class,) representing the total number of + selected boxes per batch and class. + """ + batch, num_boxes, _ = boxes.shape + batch_class = sorted_scores.shape[0] + num_class = batch_class // batch + + boxes_buf = tvm.tir.decl_buffer(boxes.shape, boxes.dtype, "boxes_buf", data_alignment=8) + sorted_scores_buf = tvm.tir.decl_buffer( + sorted_scores.shape, sorted_scores.dtype, "sorted_scores_buf", data_alignment=8 + ) + sorted_indices_buf = tvm.tir.decl_buffer( + sorted_indices.shape, sorted_indices.dtype, "sorted_indices_buf", data_alignment=8 + ) + valid_count_buf = tvm.tir.decl_buffer( + valid_count.shape, "int32", "valid_count_buf", data_alignment=4 + ) + + return te.extern( + [(batch_class, num_boxes), (1, batch_class)], + [boxes, sorted_scores, sorted_indices, valid_count], + lambda ins, outs: _all_class_nms_ir( + ins[0], # boxes + ins[1], # sorted_scores + ins[2], # sorted_indices + ins[3], # valid_count + batch_class, + num_class, + num_boxes, + iou_threshold, + max_output_size_per_class, + outs[0], # box_indices + outs[1], # num_selected_boxes + nms_loop, + ), + dtype=["int32", "int32"], + in_buffers=[ + boxes_buf, + sorted_scores_buf, + sorted_indices_buf, + valid_count_buf, + ], + name="all_class_nms", + tag="all_class_nms", + ) diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 9316fecddca7c..53cd71745d5b4 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -23,6 +23,7 @@ */ #include #include +#include namespace tvm { namespace relay { @@ -132,5 +133,67 @@ ignore class_id axis. .set_support_level(5) .add_type_rel("NMS", NMSRel); +TVM_REGISTER_NODE_TYPE(AllClassNonMaximumSuppressionAttrs); + +bool AllClassNMSRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 6); + const auto* boxes = types[0].as(); + if (boxes == nullptr) return false; + const auto* scores = types[1].as(); + if (scores == nullptr) return false; + + const auto& boxes_shape = boxes->shape; + const auto& scores_shape = scores->shape; + ICHECK_EQ(boxes_shape.size(), 3) << "Input boxes should be 3-D."; + ICHECK_EQ(scores_shape.size(), 3) << "Input scores count should be 3-D."; + + IndexExpr batch = boxes_shape[0]; + IndexExpr num_classes = scores_shape[1]; + IndexExpr num_boxes = boxes_shape[1]; + + IndexExpr num_total_boxes = Any(); + if (!batch.as() && !num_boxes.as()) { + num_total_boxes = batch * num_classes * num_boxes; + } + + // assign output type + std::vector fields; + std::vector oshape{num_total_boxes, 3}; + fields.push_back(TensorType(oshape, DataType::Int(64))); + std::vector countshape{1}; + fields.push_back(TensorType(countshape, DataType::Int(64))); + reporter->Assign(types[5], TupleType(Array(fields))); + return true; +} + +Expr MakeAllClassNMS(Expr boxes, Expr scores, Expr max_output_boxes_per_class, Expr iou_threshold, + Expr score_threshold) { + auto attrs = make_object(); + static const Op& op = Op::Get("vision.all_class_non_max_suppression"); + return Call(op, {boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold}, + Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.vision._make.all_class_non_max_suppression") + .set_body_typed(MakeAllClassNMS); + +RELAY_REGISTER_OP("vision.all_class_non_max_suppression") + .describe(R"doc(Non-maximum suppression operator for object detection, corresponding to ONNX + NonMaxSuppression and TensorFlow combined_non_max_suppression. 
+ NMS is performed for each class separately +)doc" TVM_ADD_FILELINE) + .set_num_inputs(5) + .add_argument("boxes", "Tensor", "The input boxes in the format [batch, num_boxes, 4].") + .add_argument("scores", "Tensor", + "Scores for each box and class in the format [batch, num_classes, num_boxes].") + .add_argument("max_output_boxes_per_class", "Tensor", + "The maximum number of output boxes per class.") + .add_argument("iou_threshold", "Tensor", "The IoU threshold for box the overlap test.") + .add_argument("score_threshold", "Tensor", + "The score threshold to filter out low score boxes early.") + .set_support_level(5) + .add_type_rel("AllClassNMS", AllClassNMSRel); + } // namespace relay } // namespace tvm diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 8a63bac33923a..6d22b5afd0df4 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3579,7 +3579,7 @@ def verify_roi_align( # ONNX implementation of roi_align with max mode is incorrect, so we don't compare outputs here. -# @tvm.testing.uses_gpu +@tvm.testing.uses_gpu def test_non_max_suppression(): def verify_nms( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_dims diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 2d6c8b50fd371..466b1b19a582d 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -21,7 +21,6 @@ import tvm from tvm import te from tvm import relay -from tvm.relay import transform from tvm.relay.testing import run_infer_type import tvm.topi.testing import tvm.testing @@ -371,8 +370,6 @@ def verify_nms( ) if isinstance(z_indices, relay.expr.TupleWrapper): z_indices = z_indices.astuple() - assert "iou_threshold" in z.astext() - assert "iou_threshold" in z_indices.astext() zz = run_infer_type(z) zz_indices = run_infer_type(z_indices) assert zz.checked_type == relay.ty.TensorType(dshape, "float32") @@ -1364,6 +1361,99 @@ def verify_batch_to_space_nd(dshape, block_shape, crops): verify_batch_to_space_nd([8, 1, 3, 1], [2, 2], [[0, 0], [2, 0]]) +@tvm.testing.uses_gpu +def test_all_class_non_max_suppression(): + def verify_all_class_non_max_suppression( + boxes_np, + scores_np, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + expected_indices, + ): + boxes = relay.var("boxes", relay.ty.TensorType(boxes_np.shape, "float32")) + scores = relay.var("scores", relay.ty.TensorType(scores_np.shape, "float32")) + + out = relay.vision.all_class_non_max_suppression( + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + ) + + func = relay.Function([boxes, scores], out.astuple()) + func = run_infer_type(func) + + for target, dev in tvm.testing.enabled_targets(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, device=dev, target=target) + selected_indices, num_detections = intrp.evaluate(func)(boxes_np, scores_np) + tvm_res = selected_indices.asnumpy()[: num_detections.asnumpy()[0]] + np.testing.assert_equal(tvm_res, expected_indices) + + boxes = np.array( + [ + [ + [0.0, 0.0, 0.3, 0.3], + [0.0, 0.0, 0.4, 0.4], + [0.0, 0.0, 0.5, 0.5], + [0.5, 0.5, 0.9, 0.9], + [0.5, 0.5, 1.0, 1.0], + ], + [ + [0.0, 0.0, 0.3, 0.3], + [0.0, 0.0, 0.4, 0.4], + [0.5, 0.5, 0.95, 0.95], + [0.5, 0.5, 0.96, 0.96], + [0.5, 0.5, 1.0, 1.0], + ], + ] + ).astype("float32") + + scores = np.array( + [ + [[0.1, 0.2, 0.6, 0.3, 0.9], [0.1, 0.2, 0.6, 0.3, 0.9]], + [[0.1, 0.2, 
0.6, 0.3, 0.9], [0.1, 0.2, 0.6, 0.3, 0.9]], + ] + ).astype("float32") + + max_output_boxes_per_class = 2 + iou_threshold = 0.8 + score_threshold = 0.0 + + expected = np.array( + [[0, 0, 4], [0, 0, 2], [0, 1, 4], [0, 1, 2], [1, 0, 4], [1, 0, 1], [1, 1, 4], [1, 1, 1]] + ) + + verify_all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected + ) + + boxes = np.array( + [ + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.1, 1.0, 1.1], + [0.0, -0.1, 1.0, 0.9], + [0.0, 10.0, 1.0, 11.0], + [0.0, 10.1, 1.0, 11.1], + [0.0, 100.0, 1.0, 101.0], + ] + ] + ).astype(np.float32) + scores = np.array([[[0.9, 0.75, 0.6, 0.95, 0.5, 0.3]]]).astype(np.float32) + max_output_boxes_per_class = 3 + iou_threshold = 0.5 + score_threshold = 0.4 + + expected = np.array([[0, 0, 3], [0, 0, 0]]) + + verify_all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected + ) + + if __name__ == "__main__": test_resize_infer_type() test_resize() @@ -1386,3 +1476,4 @@ def verify_batch_to_space_nd(dshape, block_shape, crops): test_affine_grid() test_grid_sample() test_space_to_batch_nd() + test_all_class_non_max_suppression() diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index 7f8712c55fd19..d260262658648 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -65,6 +65,11 @@ "gpu": (topi.cuda.proposal, topi.cuda.schedule_proposal), } +_all_class_nms_implement = { + "generic": (topi.vision.all_class_non_max_suppression, topi.generic.schedule_nms), + "gpu": (topi.cuda.all_class_non_max_suppression, topi.cuda.schedule_nms), +} + def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): dtype = "float32" @@ -623,6 +628,112 @@ def test_proposal(): verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs) +def verify_all_class_non_max_suppression( + boxes_np, + scores_np, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + expected_indices, +): + dshape = boxes_np.shape + batch, num_boxes, _ = dshape + _, num_class, _ = scores_np.shape + boxes = te.placeholder(dshape, name="boxes") + scores = te.placeholder(scores_np.shape, dtype="float32", name="scores") + + def check_device(target): + dev = tvm.device(target, 0) + if not tvm.testing.device_enabled(target): + print("Skip because %s is not enabled" % target) + return + print("Running on target: %s" % target) + with tvm.target.Target(target): + fcompute, fschedule = tvm.topi.testing.dispatch(target, _all_class_nms_implement) + out = fcompute( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + ) + s = fschedule(out) + + tvm_boxes = tvm.nd.array(boxes_np, dev) + tvm_scores = tvm.nd.array(scores_np, dev) + selected_indices = tvm.nd.array(np.zeros((batch * num_class * num_boxes, 3), "int64"), dev) + num_detections = tvm.nd.array(np.zeros((1,), "int64"), dev) + + f = tvm.build(s, [boxes, scores, out[0], out[1]], target) + f(tvm_boxes, tvm_scores, selected_indices, num_detections) + + tvm_res = selected_indices.asnumpy()[: num_detections.asnumpy()[0]] + np.testing.assert_equal(tvm_res, expected_indices) + + for target in ["llvm", "cuda", "opencl", "vulkan"]: + check_device(target) + + +@tvm.testing.uses_gpu +def test_all_class_non_max_suppression(): + boxes = np.array( + [ + [ + [0.0, 0.0, 0.3, 0.3], + [0.0, 0.0, 0.4, 0.4], + [0.0, 0.0, 0.5, 0.5], + [0.5, 0.5, 0.9, 0.9], + [0.5, 0.5, 1.0, 1.0], + ], + [ + 
[0.0, 0.0, 0.3, 0.3], + [0.0, 0.0, 0.4, 0.4], + [0.5, 0.5, 0.95, 0.95], + [0.5, 0.5, 0.96, 0.96], + [0.5, 0.5, 1.0, 1.0], + ], + ] + ).astype("float32") + + scores = np.array( + [ + [[0.1, 0.2, 0.6, 0.3, 0.9], [0.1, 0.2, 0.6, 0.3, 0.9]], + [[0.1, 0.2, 0.6, 0.3, 0.9], [0.1, 0.2, 0.6, 0.3, 0.9]], + ] + ).astype("float32") + + max_output_boxes_per_class = 2 + iou_threshold = 0.8 + score_threshold = 0.0 + + expected = np.array( + [[0, 0, 4], [0, 0, 2], [0, 1, 4], [0, 1, 2], [1, 0, 4], [1, 0, 1], [1, 1, 4], [1, 1, 1]] + ) + + verify_all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected + ) + + boxes = np.array( + [ + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.1, 1.0, 1.1], + [0.0, -0.1, 1.0, 0.9], + [0.0, 10.0, 1.0, 11.0], + [0.0, 10.1, 1.0, 11.1], + [0.0, 100.0, 1.0, 101.0], + ] + ] + ).astype(np.float32) + scores = np.array([[[0.9, 0.75, 0.6, 0.95, 0.5, 0.3]]]).astype(np.float32) + max_output_boxes_per_class = 3 + iou_threshold = 0.5 + score_threshold = 0.4 + + expected = np.array([[0, 0, 3], [0, 0, 0]]) + + verify_all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected + ) + + if __name__ == "__main__": test_get_valid_counts() test_multibox_prior() @@ -631,3 +742,4 @@ def test_proposal(): test_roi_pool() test_proposal() test_non_max_suppression() + test_all_class_non_max_suppression()
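As a cross-check on the expected outputs used in these tests, here is a minimal NumPy sketch of the operator's semantics (an illustrative reference only, not the TOPI implementation; `ref_all_class_nms` is our name, and corner-ordered (x1, y1, x2, y2) boxes are assumed):

    import numpy as np

    def ref_all_class_nms(boxes, scores, max_out, iou_thres, score_thres):
        # boxes: (batch, num_boxes, 4), scores: (batch, num_class, num_boxes)
        def iou(a, b):
            tl = np.maximum(a[:2], b[:2])  # top-left of the intersection
            br = np.minimum(a[2:], b[2:])  # bottom-right of the intersection
            inter = np.prod(np.maximum(br - tl, 0.0))
            union = ((a[2] - a[0]) * (a[3] - a[1])
                     + (b[2] - b[0]) * (b[3] - b[1]) - inter)
            return inter / union if union > 0 else 0.0

        rows = []
        for b in range(scores.shape[0]):
            for c in range(scores.shape[1]):
                keep = []
                for i in np.argsort(-scores[b, c]):  # descending by score
                    if len(keep) >= max_out:
                        break
                    if scores[b, c, i] <= score_thres:  # threshold test is strict
                        continue
                    if all(iou(boxes[b, i], boxes[b, k]) < iou_thres for k in keep):
                        keep.append(i)
                rows += [[b, c, k] for k in keep]
        # rows are (batch_id, class_id, box_index) triples, batch 0 / class 0 first,
        # in descending order of score within each (batch, class) pair
        return np.array(rows, dtype="int64").reshape(-1, 3), len(rows)

Because the rows are already grouped by (batch, class) and sorted by score, slicing off the first `num_total_detection` rows, as the ONNX importer above does with `strided_slice`, yields the ONNX NonMaxSuppression output directly.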