diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 005b900d5d440..45d01ff44fe79 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -117,6 +117,9 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { + std::string output_format; + TVM_ATTR_FIELD(output_format).set_default("onnx").describe( + "Output format. onnx or tensorflow"); TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs, "relay.attrs.AllClassNonMaximumSuppressionAttrs") {} }; diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 86edb042b8b85..a99c1084f60e5 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -824,7 +824,7 @@ def _impl(inputs, attr, params, mod): # Transpose (batch_size, num_boxes, num_classes) -> (batch_size, num_classes, num_boxes) scores = _op.transpose(scores, [0, 2, 1]) indices, count = _op.vision.all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format="tensorflow" ) # Slice indices to count three = _op.const(np.array([3]), dtype="int64") diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index 3f829e0b1cc7d..703edb040bb74 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -152,7 +152,7 @@ def non_max_suppression( def all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0 + boxes, scores, max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0, output_format="onnx" ): """Non-maximum suppression operator for object detection, corresponding to ONNX NonMaxSuppression and TensorFlow combined_non_max_suppression. 
@@ -185,6 +185,7 @@ def all_class_non_max_suppression( in descending of scores, followed by boxes from batch 0, class 1 etc. Out of `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection` rows are valid. + TODO(trvmorr): explain tf mode """ if not isinstance(max_output_boxes_per_class, expr.Expr): max_output_boxes_per_class = expr.const(max_output_boxes_per_class, "int32") @@ -194,6 +195,6 @@ def all_class_non_max_suppression( score_threshold = expr.const(score_threshold, "float32") out = _make.all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format ) return expr.TupleWrapper(out, 2) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 9a3b86d72b189..991a572cec2e9 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -989,7 +989,7 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro def all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format ): """Non-maximum suppression operator for object detection, corresponding to ONNX NonMaxSuppression and TensorFlow combined_non_max_suppression. 
@@ -1011,6 +1011,8 @@ def all_class_non_max_suppression( score_threshold : float or tvm.te.Tensor, optional Score threshold to filter out low score boxes early + + output_format : str Returns ------- @@ -1043,8 +1045,11 @@ def all_class_non_max_suppression( num_detections, return_reduction=True, output_dtype="int64" ) - selected_indices = collect_selected_indices( - num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir - ) + if output_format == "onnx": + selected_indices = collect_selected_indices( + num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + ) + else: + selected_indices = reshape(selected_indices, (batch, num_class, num_boxes)) return [selected_indices, num_total_detections] diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 744c5ef7feda1..455111086bec2 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -728,7 +728,7 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro def all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format ): """Non-maximum suppression operator for object detection, corresponding to ONNX NonMaxSuppression and TensorFlow combined_non_max_suppression. 
@@ -750,6 +750,8 @@ def all_class_non_max_suppression( score_threshold : float or tvm.te.Tensor, optional Score threshold to filter out low score boxes early + + output_format : TODO Returns ------- @@ -783,8 +785,9 @@ def all_class_non_max_suppression( num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1) - selected_indices = collect_selected_indices( - num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir - ) + if output_format == "onnx": + selected_indices = collect_selected_indices( + num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + ) return [selected_indices, num_total_detections] diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py index 1147b1687783d..60a7a344b2043 100644 --- a/python/tvm/topi/vision/nms_util.py +++ b/python/tvm/topi/vision/nms_util.py @@ -128,6 +128,61 @@ def collect_selected_indices(num_class, selected_indices, num_detections, row_of ) +def collect_selected_indices_tf(num_class, selected_indices, num_detections, row_offsets, ir): + """Collect selected indices from the core NMS loop into one output tensor per batch + + Parameters + ---------- + num_class : int + + selected_indices : tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes), representing the indices + of selected boxes by the core NMS loop. + + num_detections : tvm.te.Tensor + 1-D tensor with shape (batch_size * num_classes,), representing + the number of boxes selected by the core NMS loop, per batch and class + + row_offsets : tvm.te.Tensor + 1-D tensor with shape (batch_size * num_classes,), this should be the exclusive scan + of num_detections + + ir : function + A function to generate IR for CPU or GPU, see its usage in vision/nms.py and cuda/nms.py + + Returns + ------- + out : tvm.te.Tensor + The output is indices of size (batch_size, num_class * num_boxes, 2). 
+ Rows of indices are ordered such that selected boxes from batch 0, class 0 come + first, in descending of scores, followed by boxes from batch 0, class 1 etc. + """ + batch_class, num_boxes = selected_indices.shape + # selected_indices is flattened over (batch, class); recover the batch dimension. + batch_size = batch_class // num_class + + selected_indices_buf = tvm.tir.decl_buffer( + selected_indices.shape, selected_indices.dtype, "selected_indices_buf", data_alignment=8 + ) + num_detections_buf = tvm.tir.decl_buffer( + num_detections.shape, num_detections.dtype, "num_detections_buf", data_alignment=8 + ) + row_offsets_buf = tvm.tir.decl_buffer( + row_offsets.shape, row_offsets.dtype, "row_offsets_buf", data_alignment=8 + ) + + return te.extern( + [(batch_size, num_class * num_boxes, 2)], + [selected_indices, num_detections, row_offsets], + lambda ins, outs: ir(num_class, ins[0], ins[1], ins[2], outs[0]), + dtype=["int64"], + in_buffers=[selected_indices_buf, num_detections_buf, row_offsets_buf], + name="collect_indices", + tag="collect_indices", + ) + + + def _all_class_nms_ir( boxes, sorted_scores,