Skip to content

Commit

Permalink
pyformat
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 5, 2021
1 parent d4daa1a commit cd678fc
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
9 changes: 2 additions & 7 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,17 +1056,12 @@ def nms_strategy(attrs, inputs, out_type, target):

def wrap_compute_all_class_nms(topi_compute):
"""wrap nms topi compute"""

def _compute_nms(attrs, inputs, out_type):
max_output_size = inputs[2]
iou_threshold = inputs[3]
score_threshold = inputs[4]
return topi_compute(
inputs[0],
inputs[1],
max_output_size,
iou_threshold,
score_threshold
)
return topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold, score_threshold)

return _compute_nms

Expand Down
12 changes: 9 additions & 3 deletions python/tvm/topi/vision/nms_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _all_class_nms_ir(
max_output_size_per_class,
box_indices,
num_valid_boxes,
nms_loop
nms_loop,
):
ib = tvm.tir.ir_builder.create()
boxes = ib.buffer_ptr(boxes)
Expand Down Expand Up @@ -164,7 +164,13 @@ def needs_bbox_check(i, j, k):


def run_all_class_nms(
boxes, sorted_scores, sorted_indices, valid_count, max_output_size_per_class, iou_threshold, nms_loop
boxes,
sorted_scores,
sorted_indices,
valid_count,
max_output_size_per_class,
iou_threshold,
nms_loop,
):
batch, num_boxes, _ = boxes.shape
batch_class = sorted_scores.shape[0]
Expand Down Expand Up @@ -196,7 +202,7 @@ def run_all_class_nms(
max_output_size_per_class,
outs[0], # box_indices
outs[1], # num_valid_boxes
nms_loop
nms_loop,
),
dtype=["int32", "int32"],
in_buffers=[
Expand Down
2 changes: 1 addition & 1 deletion tests/python/topi/python/test_topi_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ def check_device(target):

f = tvm.build(s, [boxes, scores, out[0], out[1]], target)
f(tvm_boxes, tvm_scores, selected_indices, num_detections)
print(selected_indices.asnumpy()[:num_detections.asnumpy()[0]])
print(selected_indices.asnumpy()[: num_detections.asnumpy()[0]])
# tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4)

for target in ["llvm", "cuda"]:
Expand Down

0 comments on commit cd678fc

Please sign in to comment.