[TOPI, Relay] A new NMS op variant for ONNX NMS / TF Combined NMS (apache#7796)

* initial import

* add c++ boilerplate

* add python boilerplate

* update onnx frontend

* fixing

* update onnx frontend

* fix shape

* minor update

* fix

* fix shape func

* fix for no box

* more fix

* made things 64 bit

* int64 tweak

* max_output_size doesn't need to be a callback

* remove all_class_nms schedule

* minor simplify

* remove expand_dim

* refactoring

* simplify nms loop

* cpu all_class_nms stub

* updating ir for cpu

* working with cpu

* update cpu strategy, relay op also working

* fix cpplint

* fixing pylint

* enable gpu test for onnx nms

* tweak parallel

* pyformat and lint

* fix relay nms test

* doc update for cpp relay

* updating tests

* updated tests

* fix converting score_threshold to Expr

* update doc

* doc fix

Co-authored-by: Masahiro Masuda <masahi129@gmail.com>
2 people authored and Trevor Morris committed May 6, 2021
1 parent 1b9280c commit 19a1536
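
For orientation, here is a minimal sketch (not part of this commit) of calling the new Relay op. The positional signature mirrors the frontend call in the onnx.py diff below; the shapes and threshold values are illustrative, and the assumed return value is a pair of padded (batch, class, box) index triples plus a valid-row count, matching the shape function added in _vision.py.

```python
from tvm import relay

# Illustrative shapes: boxes (batch, num_boxes, 4), scores (batch, num_classes, num_boxes).
boxes = relay.var("boxes", shape=(1, 6, 4), dtype="float32")
scores = relay.var("scores", shape=(1, 2, 6), dtype="float32")

# Assumed return: (selected_indices, num_total_detections).
nms_out = relay.vision.all_class_non_max_suppression(
    boxes,
    scores,
    relay.const(3, "int64"),      # max_output_boxes_per_class
    relay.const(0.5, "float32"),  # iou_threshold
    relay.const(0.1, "float32"),  # score_threshold
)
```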
Showing 18 changed files with 1,171 additions and 397 deletions.
11 changes: 7 additions & 4 deletions include/tvm/relay/attrs/vision.h
@@ -86,8 +86,6 @@ struct GetValidCountsAttrs : public tvm::AttrsNode<GetValidCountsAttrs> {

/*! \brief Attributes used in non_maximum_suppression operator */
struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionAttrs> {
Optional<Integer> max_output_size;
Optional<FloatImm> iou_threshold;
bool force_suppress;
int top_k;
int coord_start;
@@ -97,8 +95,6 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionAttrs> {
bool invalid_to_bottom;

TVM_DECLARE_ATTRS(NonMaximumSuppressionAttrs, "relay.attrs.NonMaximumSuppressionAttrs") {
TVM_ATTR_FIELD(max_output_size).describe("Max number of output valid boxes for each instance.");
TVM_ATTR_FIELD(iou_threshold).describe("Non-maximum suppression iou threshold.");
TVM_ATTR_FIELD(force_suppress)
.set_default(false)
.describe("Suppress all detections regardless of class_id.");
@@ -118,6 +114,13 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionAttrs> {
}
};

/*! \brief Attributes used in all_class_non_maximum_suppression operator */
struct AllClassNonMaximumSuppressionAttrs
: public tvm::AttrsNode<AllClassNonMaximumSuppressionAttrs> {
TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs,
"relay.attrs.AllClassNonMaximumSuppressionAttrs") {}
};

/*! \brief Attributes used in roi_align operators */
struct ROIAlignAttrs : public tvm::AttrsNode<ROIAlignAttrs> {
Array<IndexExpr> pooled_size;
238 changes: 7 additions & 231 deletions python/tvm/relay/frontend/onnx.py
@@ -2456,26 +2456,13 @@ class NonMaxSuppression(OnnxOpConverter):

@classmethod
def _impl_v10(cls, inputs, attr, params):
"""
High level note: ONNX implements what TF calls combined_non_max_suppression
It passes in scores for each box for every class in the output and expects boxes to be
analyzed for each class independently
It also asks for the data to be returned in a particular format.
To support these, we implement a series of loops:
The first loop splits over class number, performs NMS, and collects the outputs.
The second (nested) loop takes the outputs and transforms them into the format ONNX wants.
"""
# Get parameter values
boxes = inputs[0]
scores = inputs[1]
max_output_boxes_per_class = inputs[2]
iou_threshold = inputs[3]
score_threshold = inputs[4]

dtype = infer_type(boxes).checked_type.dtype

if "center_point_box" in attr:
if attr["center_point_box"] != 0:
raise NotImplementedError(
@@ -2498,226 +2485,15 @@ def conditionally_squeeze_scalar(x):
iou_threshold = conditionally_squeeze_scalar(iou_threshold)
score_threshold = conditionally_squeeze_scalar(score_threshold)

## prepare utility constants
zero = _op.const(np.array([0]), dtype="int64")
one = _op.const(np.array([1]), dtype="int64")
two = _op.const(np.array([2]), dtype="int64")
three = _op.const(np.array([3]), dtype="int64")
three_ones = _op.const(np.array([1, 1, 1]), dtype="int64")
four_ones = _op.const(np.array([1, 1, 1, 1]), dtype="int64")

## First loop: split by class and perform NMS
# Create Loop Vars
i = _expr.var("i", shape=(1,), dtype="int64")
scores_var = _expr.var("scores_var", shape=(_ty.Any(), _ty.Any(), _ty.Any()), dtype=dtype)
boxes_var = _expr.var("boxes_var", shape=(_ty.Any(), _ty.Any(), 4), dtype=dtype)
max_output_boxes_per_class_var = _expr.var(
"max_output_boxes_per_class_var", shape=(), dtype="int64"
)
iou_threshold_var = _expr.var("iou_threshold_var", shape=(), dtype="float32")
score_threshold_var = _expr.var("score_threshold_var", shape=(), dtype="float32")
B = _expr.var("B", shape=(1,), dtype="int64")
C = _expr.var("C", shape=(1,), dtype="int64")
S = _expr.var("S", shape=(1,), dtype="int64")
# Outputs of first loop should be padded nms values shape (B, C, S, 3)
onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
# and sizes of valid outputs, shape (B, C, 1)
nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")

def _first_cond(
i,
scores,
boxes,
B,
C,
S,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
onnx_out,
nms_size_out,
):
# Loop over classes, end when i == C
return _op.take(_op.less(i, C), _expr.const(0))

def _first_body(
i,
scores,
boxes,
B,
C,
S,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
onnx_out,
nms_size_out,
):
# slice to get current class
begin = _op.concatenate([zero, i, zero], axis=0)
end = _op.concatenate([B, i + one, S], axis=0)
class_scores = _op.strided_slice(scores, begin, end, three_ones)
class_scores = _op.expand_dims(_op.squeeze(class_scores, [1]), -1, 1)
# combine scores and boxes
data = _op.concatenate([class_scores, boxes], axis=-1)

# get valid counts
ct, data, indices = _op.vision.get_valid_counts(
data, score_threshold=score_threshold, id_index=-1, score_index=0
)
# get_valid_counts is used here for inference performance
# ONNX NMS doesn't have parameter top_k
top_k = -1
# ONNX doesn't have class id for nms input
score_index = 0
# perform nms on current class
nms_ret = _op.vision.non_max_suppression(
data=data,
valid_count=ct,
indices=indices,
max_output_size=max_output_boxes_per_class,
iou_threshold=iou_threshold,
force_suppress=True,
top_k=top_k,
coord_start=1,
score_index=score_index,
id_index=-1,
return_indices=True,
invalid_to_bottom=False,
)
# partially prepare ONNX output format by labeling batch_num, class_id
nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1)
batch_num = _op.expand_dims(_op.arange(_op.squeeze(B, [0]), dtype="int64"), -1, 1)
batch_num = _op.broadcast_to(batch_num, shape_of(nms_ret[0], dtype="int64"))
batch_num = _op.expand_dims(batch_num, -1, 1)
class_num = _op.broadcast_to(i, shape_of(nms_padded_out, dtype="int64"))
new_onnx_out = _op.concatenate(
[batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1
)
new_onnx_out = _op.expand_dims(new_onnx_out, 1, 1)
# store valid nms outputs for this class
nms_size = _op.cast(nms_ret[1], "int64")
nms_size = _op.expand_dims(nms_size, 1, 1)
return [
i + one,
scores,
boxes,
B,
C,
S,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
_op.concatenate([onnx_out, new_onnx_out], axis=1),
_op.concatenate([nms_size_out, nms_size], axis=1),
]

# create the first loop
first_loop = _loops.while_loop(
_first_cond,
[
i,
scores_var,
boxes_var,
B,
C,
S,
max_output_boxes_per_class_var,
iou_threshold_var,
score_threshold_var,
onnx_out,
nms_size_out,
],
_first_body,
)

## Second loop slices outputs of the first loop for valid boxes and
## concats in the order ONNX wants
# Second inner Loop Vars
i = _expr.var("i", shape=(1,), dtype="int64")
j = _expr.var("j", shape=(1,), dtype="int64")
B = _expr.var("B", shape=(1,), dtype="int64")
C = _expr.var("C", shape=(1,), dtype="int64")
# Outputs of first loop should be padded nms values shape (B, C, 3)
onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
# and sizes of valid outputs, shape (B, C, 1)
nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")
out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64")

def _inner_cond(i, j, C, onnx_out, nms_size, out):
# inner loop over number of classes
return _op.take(_op.less(j, C), _expr.const(0))

def _inner_body(i, j, C, onnx_out, nms_size, out):
# slice to get current batch and class for valid box indicator
start = _op.concatenate([i, j + one, zero], axis=0)
end = _op.concatenate([i + one, j + two, one], axis=0)
num_valid_boxes = _op.reshape(_op.strided_slice(nms_size, start, end, three_ones), [1])
# slice to get current batch, class, and valid outputs
start = _op.concatenate([i, j + one, zero, zero], axis=0)
end = _op.concatenate([i + one, j + two, num_valid_boxes, three], axis=0)
new_out = _op.squeeze(_op.strided_slice(onnx_out, start, end, four_ones), [0, 1])
return i, j + one, C, onnx_out, nms_size, _op.concatenate([out, new_out], axis=0)

inner_loop = _loops.while_loop(
_inner_cond, [i, j, C, onnx_out, nms_size_out, out], _inner_body
nms_out = _op.vision.all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold
)

# Second Outer Loop Vars
i = _expr.var("i", shape=(1,), dtype="int64")
j = _expr.var("j", shape=(1,), dtype="int64")
B = _expr.var("B", shape=(1,), dtype="int64")
C = _expr.var("C", shape=(1,), dtype="int64")
# Outputs of first loop should be padded nms values shape (B, C, 3)
onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
# and sizes of valid outputs, shape (B, C, 1)
nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")
out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64")

def _outer_cond(i, B, C, onnx_out, nms_size_out, out):
# Outer loop is over batch size
return _op.take(_op.less(i, B), _expr.const(0))

def _outer_body(i, B, C, onnx_out, nms_size_out, out):
# Outer loop just calls inner loop
init_count = _op.const(np.array([0]), dtype="int64")
inner_loop_vals = inner_loop(i, init_count, C, onnx_out, nms_size_out, out)
return i + one, B, C, onnx_out, nms_size_out, _expr.TupleGetItem(inner_loop_vals, 5)

# Create the second loop
outer_loop = _loops.while_loop(
_outer_cond, [i, B, C, onnx_out, nms_size_out, out], _outer_body
)

# Call the first loop, perform NMS
B, C, S = _op.split(shape_of(scores, dtype="int64"), 3)
init_count = _op.const(np.array([0]), dtype="int64")
init_onnx_out = _op.const([1], dtype="int64")
init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, one, S, three], 0))
init_nms_size_out = _op.const([1], dtype="int64")
init_nms_size_out = _op.broadcast_to(init_nms_size_out, _op.concatenate([B, one, one], 0))
loop_vals = first_loop(
init_count,
scores,
boxes,
B,
C,
S,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
init_onnx_out,
init_nms_size_out,
)
onnx_output = _expr.TupleGetItem(loop_vals, 9)
nms_size_output = _expr.TupleGetItem(loop_vals, 10)

# Call the second loop, rework outputs into correct form
init_count = _op.const(np.array([0]).astype("int64"), dtype="int64")
init_out = _op.const(np.array([1, 1, 1]).reshape([1, 3]).astype("int64"), dtype="int64")
loop_vals = outer_loop(init_count, B, C, onnx_output, nms_size_output, init_out)
loop_out = _expr.TupleGetItem(loop_vals, 5)
return _op.strided_slice(loop_out, [1, 0], shape_of(loop_out), [1, 1])
three = _op.const(np.array([3]), dtype="int64")
begin = _op.const(np.array([0, 0]), dtype="int64")
end = _op.concatenate([nms_out[1], three], axis=0)
strides = _op.const(np.array([1, 1]), dtype="int64")
return _op.strided_slice(nms_out[0], begin, end, strides)


class ATen(OnnxOpConverter):
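
The converter's new tail replaces both Relay loops with a single op call followed by a strided_slice that keeps only the valid prefix of the padded output. A plain-numpy sketch of that trimming, with an illustrative padding value:

```python
import numpy as np

# Padded (batch_index, class_index, box_index) rows from the op, plus the
# count of rows that are actually valid; the -1 padding is illustrative.
padded = np.array([[0, 0, 3], [0, 1, 0], [-1, -1, -1], [-1, -1, -1]], dtype="int64")
num_total_detections = np.array([2], dtype="int64")

# Equivalent of strided_slice(begin=[0, 0], end=concat([num_total_detections, 3])).
selected = padded[: num_total_detections[0], :]  # -> [[0, 0, 3], [0, 1, 0]]
```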
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -304,6 +304,11 @@ class NonMaximumSuppressionAttrs(Attrs):
"""Attributes for vision.non_maximum_suppression"""


@tvm._ffi.register_object("relay.attrs.AllClassNonMaximumSuppressionAttrs")
class AllClassNonMaximumSuppressionAttrs(Attrs):
"""Attributes for vision.all_classnon_maximum_suppression"""


@tvm._ffi.register_object("relay.attrs.ROIAlignAttrs")
class ROIAlignAttrs(Attrs):
"""Attributes for vision.roi_align"""
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -955,6 +955,18 @@ def nms_strategy_cuda(attrs, inputs, out_type, target):
return strategy


@all_class_nms_strategy.register(["cuda", "gpu"])
def all_class_nms_strategy_cuda(attrs, inputs, out_type, target):
"""all class nms cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_all_class_nms(topi.cuda.all_class_non_max_suppression),
wrap_topi_schedule(topi.cuda.schedule_nms),
name="all_class_nms.cuda",
)
return strategy


@roi_align_strategy.register(["cuda", "gpu"])
def roi_align_strategy_cuda(attrs, inputs, out_type, target):
"""roi_align cuda strategy"""
28 changes: 24 additions & 4 deletions python/tvm/relay/op/strategy/generic.py
@@ -1002,10 +1002,6 @@ def wrap_compute_nms(topi_compute):
def _compute_nms(attrs, inputs, out_type):
max_output_size = inputs[3]
iou_threshold = inputs[4]
if attrs.max_output_size is not None:
max_output_size = attrs.max_output_size
if attrs.iou_threshold is not None:
iou_threshold = get_const_float(attrs.iou_threshold)
return_indices = bool(get_const_int(attrs.return_indices))
force_suppress = bool(get_const_int(attrs.force_suppress))
top_k = get_const_int(attrs.top_k)
@@ -1060,6 +1056,30 @@ def nms_strategy(attrs, inputs, out_type, target):
return strategy


def wrap_compute_all_class_nms(topi_compute):
"""wrap all class nms topi compute"""

def _compute_nms(attrs, inputs, out_type):
max_output_size = inputs[2]
iou_threshold = inputs[3]
score_threshold = inputs[4]
return topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold, score_threshold)

return _compute_nms


@override_native_generic_func("all_class_non_max_suppression_strategy")
def all_class_nms_strategy(attrs, inputs, out_type, target):
"""all class nms generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_all_class_nms(topi.vision.all_class_non_max_suppression),
wrap_topi_schedule(topi.generic.schedule_nms),
name="all_class_nms.generic",
)
return strategy
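
Note the companion change above: the attrs-based overrides for max_output_size and iou_threshold are gone from wrap_compute_nms, and the new wrapper reads all three scalars straight from inputs. Assuming the op accepts arbitrary Relay expressions for them (as the frontend call suggests), a threshold can even be a runtime graph input rather than a compile-time attribute:

```python
from tvm import relay

boxes = relay.var("boxes", shape=(1, 100, 4), dtype="float32")
scores = relay.var("scores", shape=(1, 80, 100), dtype="float32")
# The IoU threshold arrives at run time instead of being baked into attrs.
iou_threshold = relay.var("iou_threshold", shape=(), dtype="float32")

out = relay.vision.all_class_non_max_suppression(
    boxes, scores, relay.const(10, "int64"), iou_threshold, relay.const(0.0, "float32")
)
```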


# roi_align
def wrap_compute_roi_align(topi_compute):
"""wrap roi_align topi compute"""
19 changes: 19 additions & 0 deletions python/tvm/relay/op/vision/_vision.py
@@ -45,6 +45,9 @@
reg.register_strategy("vision.non_max_suppression", strategy.nms_strategy)
reg.register_pattern("vision.non_max_suppression", OpPattern.OPAQUE)

reg.register_strategy("vision.all_class_non_max_suppression", strategy.all_class_nms_strategy)
reg.register_pattern("vision.all_class_non_max_suppression", OpPattern.OPAQUE)


@script
def _get_valid_counts_shape_func(data_shape):
@@ -85,6 +88,22 @@ def nms_shape_func(attrs, inputs, _):
return [topi.math.identity(inputs[0])]


@script
def _all_class_nms_shape_func(boxes_shape, scores_shape):
out_shape = output_tensor((2,), "int64")
count_shape = output_tensor((1,), "int64")

out_shape[0] = boxes_shape[0] * scores_shape[1] * boxes_shape[1]
out_shape[1] = 3
count_shape[0] = int64(1)
return out_shape, count_shape


@reg.register_shape_func("vision.all_class_non_max_suppression", False)
def all_class_nms_shape_func(attrs, inputs, _):
return _all_class_nms_shape_func(inputs[0], inputs[1])
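
The shape function encodes the static upper bound on the number of selected rows: in the worst case every (batch, class, box) triple survives. A worked example of the arithmetic:

```python
# boxes: (batch, num_boxes, 4); scores: (batch, num_classes, num_boxes)
batch, num_boxes, num_classes = 1, 100, 80

# Worst case: every box is kept for every class in every batch.
max_rows = batch * num_classes * num_boxes  # 8000
out_shape = (max_rows, 3)  # padded (batch, class, box) index triples
count_shape = (1,)         # holds the actual number of valid rows
```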


@script
def _roi_align_shape_func_nchw(data_shape, rois_shape, pooled_size):
out = output_tensor((4,), "int64")
