Skip to content

Commit

Permalink
Merge branch 'tmp' into all_class_nms_tf
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 30, 2021
2 parents 0fa8805 + da75b2a commit c86bcf4
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 25 deletions.
68 changes: 46 additions & 22 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,14 +796,14 @@ def _impl(inputs, attr, params, mod):
def _combined_nms():
def all_class_impl(
batch_size,
max_output_boxes_per_batch,
boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
max_total_size,
clip_boxes,
mod,
):
(
selected_indices,
Expand All @@ -818,32 +818,55 @@ def all_class_impl(
max_total_size,
output_format="tensorflow",
)
nmsed_scores, topk_indices = _op.topk(
selected_scores, k=max_total_size, axis=1, ret_type="both"
box_range = _op.arange(
_op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), dtype="int64"
)
tile_batch_reps = (
_op.concatenate([batch_size, 1])
if isinstance(batch_size, tvm.tir.Any)
else _op.const([batch_size, 1])
)
box_range_2d = _op.tile(box_range, tile_batch_reps)
valid_mask = _op.cast(
_op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)), "float32"
)
topk_indices = _op.expand_dims(topk_indices, axis=0)
indices = _op.gather_nd(selected_indices, topk_indices, batch_dims=1)
num_detections = _op.minimum(num_detections, _op.const(max_total_size, dtype="int64"))

def select_topk(do_zero_pad):
    # Select the per-batch top `max_total_size` scores and their indices.
    # When the total candidate count (max_output_boxes_per_batch) is smaller
    # than max_total_size, topk cannot produce k results, so a zero-padded
    # index sequence is built instead and invalid slots are masked out.
    def true_branch():
        # Padding path: indices are 0..max_output_boxes_per_batch-1 followed
        # by zeros up to max_total_size, tiled over the batch dimension.
        arange = _op.arange(
            _op.const(0, dtype="int64"),
            _op.const(max_output_boxes_per_batch, dtype="int64"),
            dtype="int64",
        )
        pad = _op.full(
            _op.const(0, dtype="int64"), (max_total_size - max_output_boxes_per_batch,)
        )
        topk_indices = _op.tile(_op.concatenate([arange, pad], 0), tile_batch_reps)
        # Gather scores along axis 1 for the constructed indices.
        nmsed_scores = _op.gather(selected_scores, 1, topk_indices)
        # valid_mask (built from num_detections in the enclosing scope)
        # zeroes the scores in the padded region.
        nmsed_scores = nmsed_scores * valid_mask
        return nmsed_scores, topk_indices

    def false_branch():
        # Enough candidates: a plain top-k over the scores suffices.
        # Returns (values, indices) since ret_type="both".
        return _op.topk(selected_scores, k=max_total_size, axis=1, ret_type="both")

    # TODO(masahi): support dynamic num_boxes
    # NOTE(review): a relay If would defer the branch choice to runtime;
    # kept as a Python-level branch until dynamic num_boxes is supported.
    # return _expr.If(do_zero_pad, true_branch(), false_branch())
    return true_branch() if do_zero_pad else false_branch()

assert isinstance(
max_output_boxes_per_batch, int
), "dynamic number of boxes not supported yet."
nmsed_scores, topk_indices = select_topk(max_output_boxes_per_batch < max_total_size)

indices = _op.take(selected_indices, topk_indices, axis=1, batch_dims=1)
nmsed_box_indices = _op.take(indices, _op.const(1), axis=2)
nmsed_classes = _op.cast(_op.take(indices, _op.const(0), axis=2), "float32")
nmsed_boxes = _op.gather_nd(boxes, _op.expand_dims(nmsed_box_indices, axis=0), batch_dims=1)
nmsed_classes = _op.take(indices, _op.const(0), axis=2)
nmsed_boxes = _op.take(boxes, nmsed_box_indices, axis=1, batch_dims=1)

if clip_boxes:
nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32"))
nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32"))

# Fill in invalid entries with 0
box_range = _op.arange(
_op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), dtype="int64"
)
if isinstance(batch_size, tvm.tir.Any):
box_range_2d = _op.tile(box_range, _op.concatenate([batch_size, 1]))
else:
box_range_2d = _op.tile(box_range, _op.const([batch_size, 1]))

valid_mask = _op.cast(
_op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)), "float32"
)
nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2)

return _expr.TupleWrapper(
Expand Down Expand Up @@ -877,19 +900,20 @@ def _impl(inputs, attr, params, mod):
q = boxes_shape[2]
num_classes = scores_shape[2]

if q == 1:
if q == 1 and isinstance(num_anchors, int):
boxes = _op.squeeze(boxes, axis=[2])
scores_trans = _op.transpose(scores, [0, 2, 1])
max_output_boxes_per_batch = num_anchors * num_classes
return all_class_impl(
batch_size,
max_output_boxes_per_batch,
boxes,
scores_trans,
max_output_size,
iou_threshold,
score_threshold,
max_total_size.data.numpy().item(),
attr["clip_boxes"],
mod,
)

boxes = _op.reshape(boxes, newshape=[batch_size, num_anchors * num_classes, 4])
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust
from tvm.ir import register_intrin_lowering
from tvm.tir import if_then_else
from .sort import argsort, argsort_thrust, topk
from .sort import argsort, argsort_thrust
from ..broadcast import minimum
from .scan import exclusive_scan
from ..utils import ceil_div
from ..math import cast
Expand Down Expand Up @@ -1045,7 +1046,7 @@ def _collect_selected_indices_and_scores_ir(
)
collected_indices[batch_id, offset, 0] = zero
collected_indices[batch_id, offset, 1] = zero
collected_scores[batch_id, offset] = -1.0
collected_scores[batch_id, offset] = 0.0

return ib.get()

Expand Down Expand Up @@ -1133,4 +1134,6 @@ def all_class_non_max_suppression(
_collect_selected_indices_and_scores_ir,
)

num_total_detections = minimum(num_total_detections, max_total_size)

return [selected_indices, selected_scores, num_total_detections]
5 changes: 4 additions & 1 deletion python/tvm/topi/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ..sort import argsort
from ..math import cast
from ..transform import reshape, gather
from ..broadcast import minimum
from .. import reduction
from ..scan import cumsum
from .nms_util import (
Expand Down Expand Up @@ -772,7 +773,7 @@ def _collect_selected_indices_and_scores_ir(
)
collected_indices[batch_id, offset, 0] = zero
collected_indices[batch_id, offset, 1] = zero
collected_scores[batch_id, offset] = -1.0
collected_scores[batch_id, offset] = 0.0

return ib.get()

Expand Down Expand Up @@ -861,4 +862,6 @@ def all_class_non_max_suppression(
_collect_selected_indices_and_scores_ir,
)

num_total_detections = minimum(num_total_detections, max_total_size)

return [selected_indices, selected_scores, num_total_detections]

0 comments on commit c86bcf4

Please sign in to comment.