Skip to content

Commit

Permalink
relay type inference works, debugging topi
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 29, 2021
1 parent 4a4b8df commit 7f5c76d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
15 changes: 11 additions & 4 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,19 +828,26 @@ def _impl(inputs, attr, params, mod):
# Transpose (batch_size, num_boxes, num_classes) -> (batch_size, num_classes, num_boxes)
scores_trans = _op.transpose(scores, [0, 2, 1])

print(max_output_boxes_per_class)
print(iou_threshold)
print(score_threshold)

indices, num_detections = _op.vision.all_class_non_max_suppression(
boxes,
scores_trans,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
max_total_size,
max_total_size.data.numpy().item(),
output_format="tensorflow",
)

nmsed_box_indices = _op.take(indices, 0, axis=2)
nmsed_classes = _op.take(indices, 1, axis=2)
nmsed_boxes = _op.gather_nd(boxes, nmsed_box_indices, batch_dims=1)
nmsed_box_indices = _op.take(indices, _op.const(0), axis=2)
nmsed_classes = _op.take(indices, _op.const(1), axis=2)
nmsed_boxes = _op.gather_nd(boxes, _op.expand_dims(nmsed_box_indices, axis=0), batch_dims=1)

indices_dims = len(_infer_shape(indices, mod))
indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
nmsed_scores = _op.gather_nd(scores, indices, batch_dims=1)

if attr["clip_boxes"]:
Expand Down
9 changes: 5 additions & 4 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ..math import cast
from .. import reduction
from ..broadcast import minimum
from ..transform import reshape, strided_slice, gather_nd
from ..transform import reshape, strided_slice, gather_nd, expand_dims
from ..vision.nms_util import (
calculate_overlap,
binary_search,
Expand Down Expand Up @@ -1005,6 +1005,7 @@ def _collect_selected_indices_tf_ir(
ib = tvm.tir.ir_builder.create()

selected_indices = ib.buffer_ptr(selected_indices)
selected_scores = ib.buffer_ptr(selected_scores)
num_detections = ib.buffer_ptr(num_detections)
row_offsets = ib.buffer_ptr(row_offsets)
collected_indices = ib.buffer_ptr(collected_indices)
Expand Down Expand Up @@ -1043,7 +1044,7 @@ def collect_selected_indices_tf(selected_indices, selected_scores, num_detection
selected_indices.shape, selected_indices.dtype, "selected_indices_buf", data_alignment=8
)
selected_scores_buf = tvm.tir.decl_buffer(
selected_scores.shape, selected_indices.dtype, "selected_scores_buf", data_alignment=8
selected_scores.shape, selected_scores.dtype, "selected_scores_buf", data_alignment=8
)
num_detections_buf = tvm.tir.decl_buffer(
num_detections.shape, num_detections.dtype, "num_detections_buf", data_alignment=8
Expand All @@ -1058,7 +1059,7 @@ def collect_selected_indices_tf(selected_indices, selected_scores, num_detection
lambda ins, outs: _collect_selected_indices_tf_ir(
num_class, ins[0], ins[1], ins[2], ins[3], outs[0], outs[1]
),
dtype=["int64"],
dtype=["int64", "float32"],
in_buffers=[selected_indices_buf, selected_scores_buf, num_detections_buf, row_offsets_buf],
name="collect_indices",
tag="collect_indices",
Expand Down Expand Up @@ -1152,6 +1153,6 @@ def all_class_non_max_suppression(
selected_scores, begin=[0, 0], end=[batch, reduction.max(num_total_detections)]
)
topk_indices = topk(selected_scores, k=max_detection_per_batch, axis=1, ret_type="indices")
final_indices = gather_nd(selected_indices, topk_indices, batch_dims=1)
final_indices = gather_nd(selected_indices, expand_dims(topk_indices, axis=0), batch_dims=1)
num_detections = minimum(num_total_detections, max_detection_per_batch)
return [final_indices, num_detections]
4 changes: 2 additions & 2 deletions python/tvm/topi/vision/nms_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def run_all_class_nms(
valid_count.shape, "int32", "valid_count_buf", data_alignment=4
)

if return_scores:
if return_scores is False:
return te.extern(
[(batch_class, num_boxes), (1, batch_class)],
[boxes, sorted_scores, sorted_indices, valid_count],
Expand Down Expand Up @@ -362,7 +362,7 @@ def run_all_class_nms(
outs[2], # num_selected_boxes
nms_loop,
),
dtype=["int32", "int32"],
dtype=["int32", "float32", "int32"],
in_buffers=[
boxes_buf,
sorted_scores_buf,
Expand Down

0 comments on commit 7f5c76d

Please sign in to comment.