update op definition
masahi committed May 29, 2021
1 parent 8afbd30 commit 2fc5f1e
Showing 6 changed files with 34 additions and 41 deletions.
11 changes: 10 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
@@ -804,7 +804,11 @@ def all_class_impl(
     clip_boxes,
     mod,
 ):
-    indices, num_detections = _op.vision.all_class_non_max_suppression(
+    (
+        selected_indices,
+        selected_scores,
+        num_detections,
+    ) = _op.vision.all_class_non_max_suppression(
         boxes,
         scores,
         max_output_boxes_per_class,
@@ -813,6 +817,11 @@ def all_class_impl(
         max_total_size,
         output_format="tensorflow",
     )
+    topk_indices = _op.topk(selected_scores, k=max_total_size, axis=1, ret_type="indices")
+    topk_indices = _op.expand_dims(topk_indices, axis=0)
+    indices = _op.gather_nd(selected_indices, topk_indices, batch_dims=1)
+    num_detections = _op.minimum(num_detections, _op.const(max_total_size, dtype="int64"))
+
     nmsed_box_indices = _op.take(indices, _op.const(1), axis=2)
     nmsed_classes = _op.cast(_op.take(indices, _op.const(0), axis=2), "float32")
     nmsed_boxes = _op.gather_nd(boxes, _op.expand_dims(nmsed_box_indices, axis=0), batch_dims=1)
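Note on the frontend hunks above: the converter now consumes the three raw outputs of the "tensorflow" format and performs the final top-k selection itself in Relay. The NumPy sketch below is not part of the commit; the helper name is hypothetical and the shapes are assumed (selected_indices is (batch, N, 2) holding (class_id, box_id) pairs, selected_scores is (batch, N), num_detections is (batch,)). It only illustrates the same post-processing steps.

import numpy as np

def tf_nms_postprocess(selected_indices, selected_scores, num_detections, max_total_size):
    # Pick the max_total_size highest-scoring detections per batch element
    # (Relay topk with is_ascend=False corresponds to argsort on negated scores).
    topk = np.argsort(-selected_scores, axis=1)[:, :max_total_size]
    # Gather the (class_id, box_id) pairs of the kept detections.
    indices = np.take_along_axis(selected_indices, topk[..., None], axis=1)
    # The number of valid detections is capped at max_total_size.
    num_detections = np.minimum(num_detections, max_total_size)
    return indices, num_detections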
6 changes: 5 additions & 1 deletion python/tvm/relay/op/vision/nms.py
@@ -209,4 +209,8 @@ def all_class_non_max_suppression(
         max_total_size,
         output_format,
     )
-    return expr.TupleWrapper(out, 2)
+
+    if output_format == "onnx":
+        return expr.TupleWrapper(out, 2)
+
+    return expr.TupleWrapper(out, 3)
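As a hedged usage sketch of the updated wrapper (the input shapes and threshold values below are invented for illustration), the returned TupleWrapper unpacks into two fields for the ONNX-style layout and three for the TensorFlow-style layout:

from tvm import relay

boxes = relay.var("boxes", shape=(1, 100, 4), dtype="float32")
scores = relay.var("scores", shape=(1, 5, 100), dtype="float32")

# ONNX-style layout: flattened (batch, class, box) index triples plus a scalar count.
selected_indices, num_total = relay.vision.all_class_non_max_suppression(
    boxes, scores, max_output_boxes_per_class=10,
    iou_threshold=0.5, score_threshold=0.05, output_format="onnx",
)

# TensorFlow-style layout: per-batch indices, scores, and valid counts.
selected_indices, selected_scores, num_total = relay.vision.all_class_non_max_suppression(
    boxes, scores, max_output_boxes_per_class=10, iou_threshold=0.5,
    score_threshold=0.05, max_total_size=20, output_format="tensorflow",
)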
7 changes: 1 addition & 6 deletions python/tvm/topi/cuda/nms.py
@@ -34,7 +34,6 @@
     collect_selected_indices,
     collect_selected_indices_and_scores,
     run_all_class_nms,
-    post_process_max_detections,
 )


@@ -1134,8 +1133,4 @@ def all_class_non_max_suppression(
         _collect_selected_indices_and_scores_ir,
     )

-    topk_indices = topk(selected_scores, k=max_total_size, axis=1, ret_type="indices")[0]
-
-    return post_process_max_detections(
-        selected_indices, topk_indices, num_total_detections, max_total_size
-    )
+    return [selected_indices, selected_scores, num_total_detections]
12 changes: 2 additions & 10 deletions python/tvm/topi/vision/nms.py
@@ -22,7 +22,7 @@
 from tvm.te import hybrid
 from tvm.tir import if_then_else

-from ..sort import argsort, topk
+from ..sort import argsort
 from ..math import cast
 from ..transform import reshape, gather
 from .. import reduction
@@ -32,7 +32,6 @@
     collect_selected_indices,
     collect_selected_indices_and_scores,
     run_all_class_nms,
-    post_process_max_detections,
 )


@@ -862,11 +861,4 @@ def all_class_non_max_suppression(
         _collect_selected_indices_and_scores_ir,
     )

-    topk_indices = topk(selected_scores, k=max_total_size, axis=1, ret_type="indices")
-
-    return post_process_max_detections(
-        selected_indices,
-        topk_indices,
-        num_total_detections,
-        max_total_size,
-    )
+    return [selected_indices, selected_scores, num_total_detections]
11 changes: 0 additions & 11 deletions python/tvm/topi/vision/nms_util.py
@@ -18,8 +18,6 @@
 """Common utilities used in Non-maximum suppression operators"""
 import tvm
 from tvm import te
-from ..broadcast import minimum
-from ..transform import gather_nd, expand_dims


 def _get_boundaries(output, box_idx):
@@ -302,12 +300,3 @@ def run_all_class_nms(
         name="all_class_nms",
         tag="all_class_nms",
     )
-
-
-def post_process_max_detections(
-    selected_indices, topk_indices, num_total_detections, max_total_size
-):
-    topk_indices = expand_dims(topk_indices, axis=0)
-    final_indices = gather_nd(selected_indices, topk_indices, batch_dims=1)
-    num_detections = minimum(num_total_detections, max_total_size)
-    return [final_indices, num_detections]
28 changes: 16 additions & 12 deletions src/relay/op/vision/nms.cc
@@ -152,29 +152,33 @@ bool AllClassNMSRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
   IndexExpr num_classes = scores_shape[1];
   IndexExpr num_boxes = boxes_shape[1];

-  IndexExpr num_total_boxes = Any();
-  if (!batch.as<AnyNode>() && !num_boxes.as<AnyNode>()) {
-    num_total_boxes = batch * num_classes * num_boxes;
-  }

   const auto* param = attrs.as<AllClassNonMaximumSuppressionAttrs>();
   CHECK(param);

   std::vector<Type> fields;
   if (param->output_format == "onnx") {
+    IndexExpr num_total_boxes = Any();
+    if (!batch.as<AnyNode>() && !num_boxes.as<AnyNode>()) {
+      num_total_boxes = batch * num_classes * num_boxes;
+    }
     std::vector<IndexExpr> oshape{num_total_boxes, 3};
-    std::vector<IndexExpr> countshape{1};
+    std::vector<IndexExpr> counts_shape{1};
     fields.push_back(TensorType(oshape, DataType::Int(64)));
-    fields.push_back(TensorType(countshape, DataType::Int(64)));
+    fields.push_back(TensorType(counts_shape, DataType::Int(64)));
   } else {
-    ICHECK(param->max_total_size) << "max_total_size required for tf mode";
-    Integer max_total_size = param->max_total_size.value();
-    std::vector<IndexExpr> oshape{batch, max_total_size, 2};
-    std::vector<IndexExpr> countshape{batch};
-    fields.push_back(TensorType(oshape, DataType::Int(64)));
-    fields.push_back(TensorType(countshape, DataType::Int(64)));
+    IndexExpr num_total_boxes_per_batch = Any();
+    if (!num_boxes.as<AnyNode>()) {
+      num_total_boxes_per_batch = num_classes * num_boxes;
+    }
+    std::vector<IndexExpr> indices_shape{batch, num_total_boxes_per_batch, 2};
+    std::vector<IndexExpr> scores_shape{batch, num_total_boxes_per_batch};
+    std::vector<IndexExpr> counts_shape{batch};
+    fields.push_back(TensorType(indices_shape, DataType::Int(64)));
+    fields.push_back(TensorType(scores_shape, DataType::Float(32)));
+    fields.push_back(TensorType(counts_shape, DataType::Int(64)));
   }

   reporter->Assign(types[5], TupleType(Array<Type>(fields)));
   return true;
 }
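A small, assumption-laden sketch of what this relation reports for the "tensorflow" format (the same invented shapes as the earlier sketch: batch=1, num_classes=5, num_boxes=100), checked through Relay type inference:

import tvm
from tvm import relay

boxes = relay.var("boxes", shape=(1, 100, 4), dtype="float32")
scores = relay.var("scores", shape=(1, 5, 100), dtype="float32")
out = relay.vision.all_class_non_max_suppression(
    boxes, scores, max_output_boxes_per_class=10, iou_threshold=0.5,
    score_threshold=0.05, max_total_size=20, output_format="tensorflow",
)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(out.astuple()))
# Expected, per the relation above:
#   (Tensor[(1, 500, 2), int64], Tensor[(1, 500), float32], Tensor[(1,), int64])
print(mod["main"].body.checked_type)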
