Skip to content

Commit

Permalink
fixing pylint
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 5, 2021
1 parent cd678fc commit da3eaf9
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 18 deletions.
5 changes: 1 addition & 4 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
def _nms_loop(
ib,
batch_size,
num_anchors,
top_k,
iou_threshold,
max_output_size,
Expand Down Expand Up @@ -589,7 +588,6 @@ def needs_bbox_check(i, j, k):
return _nms_loop(
ib,
batch_size,
num_anchors,
top_k,
iou_threshold,
max_output_size,
Expand Down Expand Up @@ -639,8 +637,7 @@ def _dispatch_sort(scores, ret_type="indices"):
or can_use_rocthrust(target, "tvm.contrib.thrust.sort")
):
return argsort_thrust(scores, axis=1, is_ascend=False, dtype="int32", ret_type=ret_type)
else:
return argsort(scores, axis=1, is_ascend=False, dtype="int32", ret_type=ret_type)
return argsort(scores, axis=1, is_ascend=False, dtype="int32", ret_type=ret_type)


def _get_sorted_indices(data, data_buf, score_index, score_shape):
Expand Down
12 changes: 4 additions & 8 deletions python/tvm/topi/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,12 @@
from tvm import te

from tvm.te import hybrid
from tvm.contrib import nvcc
from tvm.tir import if_then_else

from ..sort import argsort
from ..sort import sort, argsort
from ..math import cast
from ..utils import ceil_div
from ..transform import reshape
from ..reduction import sum
from ..sort import sort, argsort
from ..import reduction
from ..scan import cumsum
from .nms_util import (
binary_search,
Expand Down Expand Up @@ -616,7 +613,6 @@ def non_max_suppression(
def _nms_loop(
ib,
batch_size,
num_anchors,
top_k,
iou_threshold,
max_output_size,
Expand Down Expand Up @@ -709,7 +705,7 @@ def searchsorted_ir(scores, valid_count):


def _collect_selected_indices_ir(num_class, selected_indices, num_detections, row_offsets, out):
batch_classes, num_boxes = selected_indices.shape
batch_classes, _ = selected_indices.shape

ib = tvm.tir.ir_builder.create()

Expand Down Expand Up @@ -756,7 +752,7 @@ def all_class_non_max_suppression(

row_offsets = cumsum(num_detections, exclusive=True, dtype="int64")

num_total_detections = sum(cast(num_detections, "int64"), axis=1)
num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1)

selected_indices = collect_selected_indices(
num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir
Expand Down
16 changes: 10 additions & 6 deletions python/tvm/topi/vision/nms_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Common utilities used in Non-maximum suppression operators"""
import tvm
from tvm import te


def get_boundaries(output, box_idx):
def _get_boundaries(output, box_idx):
l = tvm.te.min(
output[box_idx],
output[box_idx + 2],
Expand All @@ -40,8 +42,8 @@ def get_boundaries(output, box_idx):

def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
"""Calculate overlap of two boxes."""
a_l, a_t, a_r, a_b = get_boundaries(out_tensor, box_a_idx)
b_l, b_t, b_r, b_b = get_boundaries(out_tensor, box_b_idx)
a_l, a_t, a_r, a_b = _get_boundaries(out_tensor, box_a_idx)
b_l, b_t, b_r, b_b = _get_boundaries(out_tensor, box_b_idx)

# Overlapping width and height
w = tvm.te.max(0.0, tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l))
Expand All @@ -57,6 +59,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):


def binary_search(ib, y, num_boxes, scores, score_threshold, out):
"""Binary search for score_threshold on scores sorted in descending order"""
lo = ib.allocate("int32", (1,), name="lo", scope="local")
hi = ib.allocate("int32", (1,), name="hi", scope="local")

Expand All @@ -74,6 +77,7 @@ def binary_search(ib, y, num_boxes, scores, score_threshold, out):


def collect_selected_indices(num_class, selected_indices, num_detections, row_offsets, ir):
"""TODO"""
batch_class, num_boxes = selected_indices.shape

selected_indices_buf = tvm.tir.decl_buffer(
Expand Down Expand Up @@ -140,16 +144,15 @@ def on_new_valid_box(ib, tid, num_current_valid_box, i, j):
with ib.if_scope(tid + 0 == 0):
box_indices[i, num_current_valid_box] = sorted_indices[i, j]

def on_new_invalidated_box(i, k):
def on_new_invalidated_box(*_):
pass

def needs_bbox_check(i, j, k):
def needs_bbox_check(*_):
return tvm.tir.const(True)

return nms_loop(
ib,
batch_class,
num_anchors,
tvm.tir.IntImm("int32", -1), # top_k
iou_threshold,
max_output_size_per_class,
Expand All @@ -172,6 +175,7 @@ def run_all_class_nms(
iou_threshold,
nms_loop,
):
"""TODO"""
batch, num_boxes, _ = boxes.shape
batch_class = sorted_scores.shape[0]
num_class = batch_class // batch
Expand Down

0 comments on commit da3eaf9

Please sign in to comment.