Skip to content

Commit

Permalink
pyformat and lint
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 6, 2021
1 parent ca7c056 commit 52b5bbd
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 29 deletions.
9 changes: 2 additions & 7 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,17 +1056,12 @@ def nms_strategy(attrs, inputs, out_type, target):

def wrap_compute_all_class_nms(topi_compute):
"""wrap nms topi compute"""

def _compute_nms(attrs, inputs, out_type):
max_output_size = inputs[2]
iou_threshold = inputs[3]
score_threshold = inputs[4]
return topi_compute(
inputs[0],
inputs[1],
max_output_size,
iou_threshold,
score_threshold
)
return topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold, score_threshold)

return _compute_nms

Expand Down
5 changes: 1 addition & 4 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
def _nms_loop(
ib,
batch_size,
num_anchors,
top_k,
iou_threshold,
max_output_size,
Expand Down Expand Up @@ -589,7 +588,6 @@ def needs_bbox_check(i, j, k):
return _nms_loop(
ib,
batch_size,
num_anchors,
top_k,
iou_threshold,
max_output_size,
Expand Down Expand Up @@ -639,8 +637,7 @@ def _dispatch_sort(scores, ret_type="indices"):
or can_use_rocthrust(target, "tvm.contrib.thrust.sort")
):
return argsort_thrust(scores, axis=1, is_ascend=False, dtype="int32", ret_type=ret_type)
else:
return argsort(scores, axis=1, is_ascend=False, dtype="int32", ret_type=ret_type)
return argsort(scores, axis=1, is_ascend=False, dtype="int32", ret_type=ret_type)


def _get_sorted_indices(data, data_buf, score_index, score_shape):
Expand Down
12 changes: 4 additions & 8 deletions python/tvm/topi/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,12 @@
from tvm import te

from tvm.te import hybrid
from tvm.contrib import nvcc
from tvm.tir import if_then_else

from ..sort import argsort
from ..sort import sort, argsort
from ..math import cast
from ..utils import ceil_div
from ..transform import reshape
from ..reduction import sum
from ..sort import sort, argsort
from .. import reduction
from ..scan import cumsum
from .nms_util import (
binary_search,
Expand Down Expand Up @@ -616,7 +613,6 @@ def non_max_suppression(
def _nms_loop(
ib,
batch_size,
num_anchors,
top_k,
iou_threshold,
max_output_size,
Expand Down Expand Up @@ -709,7 +705,7 @@ def searchsorted_ir(scores, valid_count):


def _collect_selected_indices_ir(num_class, selected_indices, num_detections, row_offsets, out):
batch_classes, num_boxes = selected_indices.shape
batch_classes, _ = selected_indices.shape

ib = tvm.tir.ir_builder.create()

Expand Down Expand Up @@ -756,7 +752,7 @@ def all_class_non_max_suppression(

row_offsets = cumsum(num_detections, exclusive=True, dtype="int64")

num_total_detections = sum(cast(num_detections, "int64"), axis=1)
num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1)

selected_indices = collect_selected_indices(
num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir
Expand Down
28 changes: 19 additions & 9 deletions python/tvm/topi/vision/nms_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Common utilities used in Non-maximum suppression operators"""
import tvm
from tvm import te


def get_boundaries(output, box_idx):
def _get_boundaries(output, box_idx):
l = tvm.te.min(
output[box_idx],
output[box_idx + 2],
Expand All @@ -40,8 +42,8 @@ def get_boundaries(output, box_idx):

def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
"""Calculate overlap of two boxes."""
a_l, a_t, a_r, a_b = get_boundaries(out_tensor, box_a_idx)
b_l, b_t, b_r, b_b = get_boundaries(out_tensor, box_b_idx)
a_l, a_t, a_r, a_b = _get_boundaries(out_tensor, box_a_idx)
b_l, b_t, b_r, b_b = _get_boundaries(out_tensor, box_b_idx)

# Overlapping width and height
w = tvm.te.max(0.0, tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l))
Expand All @@ -57,6 +59,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):


def binary_search(ib, y, num_boxes, scores, score_threshold, out):
"""Binary search for score_threshold on scores sorted in descending order"""
lo = ib.allocate("int32", (1,), name="lo", scope="local")
hi = ib.allocate("int32", (1,), name="hi", scope="local")

Expand All @@ -74,6 +77,7 @@ def binary_search(ib, y, num_boxes, scores, score_threshold, out):


def collect_selected_indices(num_class, selected_indices, num_detections, row_offsets, ir):
"""TODO"""
batch_class, num_boxes = selected_indices.shape

selected_indices_buf = tvm.tir.decl_buffer(
Expand Down Expand Up @@ -109,7 +113,7 @@ def _all_class_nms_ir(
max_output_size_per_class,
box_indices,
num_valid_boxes,
nms_loop
nms_loop,
):
ib = tvm.tir.ir_builder.create()
boxes = ib.buffer_ptr(boxes)
Expand Down Expand Up @@ -140,16 +144,15 @@ def on_new_valid_box(ib, tid, num_current_valid_box, i, j):
with ib.if_scope(tid + 0 == 0):
box_indices[i, num_current_valid_box] = sorted_indices[i, j]

def on_new_invalidated_box(i, k):
def on_new_invalidated_box(*_):
pass

def needs_bbox_check(i, j, k):
def needs_bbox_check(*_):
return tvm.tir.const(True)

return nms_loop(
ib,
batch_class,
num_anchors,
tvm.tir.IntImm("int32", -1), # top_k
iou_threshold,
max_output_size_per_class,
Expand All @@ -164,8 +167,15 @@ def needs_bbox_check(i, j, k):


def run_all_class_nms(
boxes, sorted_scores, sorted_indices, valid_count, max_output_size_per_class, iou_threshold, nms_loop
boxes,
sorted_scores,
sorted_indices,
valid_count,
max_output_size_per_class,
iou_threshold,
nms_loop,
):
"""TODO"""
batch, num_boxes, _ = boxes.shape
batch_class = sorted_scores.shape[0]
num_class = batch_class // batch
Expand Down Expand Up @@ -196,7 +206,7 @@ def run_all_class_nms(
max_output_size_per_class,
outs[0], # box_indices
outs[1], # num_valid_boxes
nms_loop
nms_loop,
),
dtype=["int32", "int32"],
in_buffers=[
Expand Down
2 changes: 1 addition & 1 deletion tests/python/topi/python/test_topi_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ def check_device(target):

f = tvm.build(s, [boxes, scores, out[0], out[1]], target)
f(tvm_boxes, tvm_scores, selected_indices, num_detections)
print(selected_indices.asnumpy()[:num_detections.asnumpy()[0]])
print(selected_indices.asnumpy()[: num_detections.asnumpy()[0]])
# tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4)

for target in ["llvm", "cuda"]:
Expand Down

0 comments on commit 52b5bbd

Please sign in to comment.