diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 925d108e5abd..2280cff3059b 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -21,7 +21,7 @@
 from tvm import te
 from tvm.tir import if_then_else
 
-from .sort import argsort, argsort_thrust
+from .sort import argsort, argsort_thrust, is_thrust_available
 
 
 def cuda_atomic_add_rule(op):
@@ -338,7 +338,9 @@ def nms_ir(
     sorted_index,
     valid_count,
     indices,
-    out,
+    out_bboxes,
+    out_scores,
+    out_class_ids,
     box_indices,
     num_valid_boxes,
     max_output_size,
@@ -458,9 +460,13 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
     sorted_index = ib.buffer_ptr(sorted_index)
     valid_count = ib.buffer_ptr(valid_count)
     indices = ib.buffer_ptr(indices)
-    num_valid_boxes = ib.buffer_ptr(num_valid_boxes)
-    out = ib.buffer_ptr(out)
+
+    # outputs
+    out_bboxes = ib.buffer_ptr(out_bboxes)
+    out_scores = ib.buffer_ptr(out_scores)
+    out_class_ids = ib.buffer_ptr(out_class_ids)
     box_indices = ib.buffer_ptr(box_indices)
+    num_valid_boxes = ib.buffer_ptr(num_valid_boxes)
 
     if isinstance(iou_threshold, float):
         iou_threshold = tvm.tir.FloatImm("float32", iou_threshold)
@@ -483,7 +489,9 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
         ib.scope_attr(tx, "thread_extent", nthread_tx)
         ib.scope_attr(bx, "thread_extent", nthread_bx)
         i = by
-        base_idx = i * num_anchors * box_data_length
+        base_src_idx = i * num_anchors * box_data_length
+        base_bbox_idx = i * num_anchors * 4
+
         with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
             # Reorder output
             nkeep = if_then_else(
@@ -491,18 +499,28 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
             )
             j = bx * max_threads + tx
             with ib.if_scope(j < nkeep):
+                src_idx = base_src_idx + sorted_index[i * num_anchors + j] * box_data_length
                 # Fill in out with sorted boxes
-                with ib.for_range(0, box_data_length) as k:
-                    out[(base_idx + j * box_data_length + k)] = data[
-                        (base_idx + sorted_index[i * num_anchors + j] * box_data_length + k)
-                    ]
+                with ib.for_range(0, 4) as k:
+                    out_bboxes[(base_bbox_idx + j * 4 + k)] = data[src_idx + coord_start + k]
+
+                out_scores[i * num_anchors + j] = data[src_idx + score_index]
+
+                if id_index >= 0:
+                    out_class_ids[i * num_anchors + j] = data[src_idx + id_index]
+
             with ib.else_scope():
                 # Indices > nkeep are discarded
                 # Only needed for return_indices = False case
                 if return_indices is False:
                     with ib.if_scope(j < num_anchors):
-                        with ib.for_range(0, box_data_length) as k:
-                            out[(base_idx + j * box_data_length + k)] = -1.0
+                        with ib.for_range(0, 4) as k:
+                            out_bboxes[(base_bbox_idx + j * 4 + k)] = -1.0
+
+                        out_scores[i, j] = -1.0
+
+                        if id_index >= 0:
+                            out_class_ids[i, j] = -1.0
 
             if return_indices:
                 with ib.if_scope(j < num_anchors):
@@ -510,9 +528,16 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
 
         with ib.else_scope():
             with ib.if_scope(j < valid_count[i]):
-                with ib.for_range(0, box_data_length) as k:
-                    offset = base_idx + j * box_data_length + k
-                    out[offset] = data[offset]
+                src_offset = base_src_idx + j * box_data_length
+
+                with ib.for_range(0, 4) as k:
+                    out_bboxes[base_bbox_idx + j * 4 + k] = data[src_offset + coord_start + k]
+
+                out_scores[i * num_anchors + j] = data[src_offset + score_index]
+
+                if id_index >= 0:
+                    out_class_ids[i * num_anchors + j] = data[src_offset + id_index]
+
                 box_indices[i * num_anchors + j] = j
 
     with ib.new_scope():
@@ -526,7 +551,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
 
         i = by
 
-        base_idx = i * num_anchors * box_data_length
+        base_bbox_idx = i * num_anchors * 4
         num_valid_boxes_local = ib.allocate(
             "int32", (1,), name="num_valid_boxes_local", scope="local"
        )
@@ -549,37 +574,35 @@ def nms_inner_loop(ib, j):
 
             num_valid_boxes_local[0] += 1
 
-            offset_j = j * box_data_length
+            offset_j = j * 4
             num_iter_per_thread = ceil_div(valid_count[i] - (j + 1), nthread_tx)
 
             with ib.for_range(0, num_iter_per_thread) as _k:
                 k = j + 1 + _k * nthread_tx + tx
-                offset_k = k * box_data_length
+                offset_k = k * 4
 
                 with ib.if_scope(
                     tvm.tir.all(
                         k < num_anchors,
-                        out[base_idx + offset_k + score_index] > 0,  # is the box k still valid?
+                        out_scores[i, k] > 0,  # is the box k still valid?
                         tvm.tir.any(
                             force_suppress > 0,
                             id_index < 0,
-                            out[base_idx + offset_k + id_index]
-                            == out[base_idx + offset_j + id_index],
+                            out_class_ids[i, k] == out_class_ids[i, j],
                         ),
                     )
                 ):
                     iou = calculate_overlap(
-                        out,
-                        base_idx + offset_j + coord_start,
-                        base_idx + offset_k + coord_start,
+                        out_bboxes,
+                        base_bbox_idx + offset_j,
+                        base_bbox_idx + offset_k,
                     )
                     with ib.if_scope(iou >= iou_threshold):
                         # invalidate the box k
-                        out[base_idx + offset_k + score_index] = -1.0
-                        with ib.if_scope(id_index >= 0):
-                            out[base_idx + offset_k + id_index] = -1.0
+                        out_scores[i, k] = -1.0
+                        if return_indices is False and id_index >= 0:
+                            out_class_ids[i, k] = -1.0
 
-            # Make sure to do the next loop in a lock step
             ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
 
         if isinstance(max_output_size, int):
@@ -589,7 +612,7 @@ def nms_inner_loop(ib, j):
         # Apply nms
         with ib.for_range(0, valid_count[i]) as j:
             # Proceed to the inner loop if the box j is still valid
-            with ib.if_scope(out[base_idx + (j * box_data_length) + score_index] > -1.0):
+            with ib.if_scope(out_scores[i, j] > -1.0):
                 with ib.if_scope(max_output_size > 0):
                     # No need to do more iteration if we already reach max_output_size boxes
                     with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
@@ -638,6 +661,33 @@ def _fetch_score_ir(data, score, axis):
     return ib.get()
 
 
+def _get_sorted_indices(data, data_buf, score_index, score_shape):
+    score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8)
+    score_tensor = te.extern(
+        [score_shape],
+        [data],
+        lambda ins, outs: _fetch_score_ir(
+            ins[0],
+            outs[0],
+            score_index,
+        ),
+        dtype=[data.dtype],
+        in_buffers=[data_buf],
+        out_buffers=[score_buf],
+        name="fetch_score",
+        tag="fetch_score",
+    )
+
+    if is_thrust_available():
+        sort_tensor = argsort_thrust(
+            score_tensor, valid_count=None, axis=1, is_ascend=False, dtype="int32"
+        )
+    else:
+        sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32")
+
+    return sort_tensor
+
+
 def non_max_suppression(
     data,
     valid_count,
@@ -736,54 +786,35 @@ def non_max_suppression(
     valid_count_buf = tvm.tir.decl_buffer(
         valid_count.shape, valid_count_dtype, "valid_count_buf", data_alignment=4
     )
-    score_axis = score_index
+    score_shape = (batch_size, num_anchors)
     data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
 
-    score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8)
-    score_tensor = te.extern(
-        [score_shape],
-        [data],
-        lambda ins, outs: _fetch_score_ir(
-            ins[0],
-            outs[0],
-            score_axis,
-        ),
-        dtype=[data.dtype],
-        in_buffers=[data_buf],
-        out_buffers=[score_buf],
-        name="fetch_score",
-        tag="fetch_score",
-    )
-    target = tvm.target.Target.current()
-    if (
-        target
-        and target.kind.name == "cuda"
-        and tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True)
-    ):
-        sort_tensor = argsort_thrust(
-            score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype
-        )
-    else:
-        sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype=valid_count_dtype)
+    sort_tensor = _get_sorted_indices(data, data_buf, score_index, score_shape)
 
     sort_tensor_buf = tvm.tir.decl_buffer(
         sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8
     )
 
-    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
     indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8)
 
-    out, box_indices, num_valid_boxes = te.extern(
-        [data.shape, score_shape, [batch_size, 1]],
+    bbox_shape = (batch_size, num_anchors, 4)
+    class_id_shape = score_shape
+    box_indices_shape = score_shape
+    num_valid_boxes_shape = (batch_size, 1)
+
+    out_bboxes, out_scores, out_sorted_ids, box_indices, num_valid_boxes = te.extern(
+        [bbox_shape, score_shape, class_id_shape, box_indices_shape, num_valid_boxes_shape],
         [data, sort_tensor, valid_count, indices],
         lambda ins, outs: nms_ir(
             ins[0],
            ins[1],
             ins[2],
             ins[3],
-            outs[0],
-            outs[1],
-            outs[2],
+            outs[0],  # sorted bbox
+            outs[1],  # sorted scores
+            outs[2],  # sorted class ids
+            outs[3],  # box_indices
+            outs[4],  # num_valid_boxes
             max_output_size,
             iou_threshold,
             force_suppress,
@@ -793,7 +824,7 @@ def non_max_suppression(
             score_index,
             return_indices,
         ),
-        dtype=[data.dtype, "int32", "int32"],
+        dtype=[data.dtype, "float32", "float32", "int32", "int32"],
         in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf],
         name="nms",
         tag="nms",
@@ -802,4 +833,9 @@ def non_max_suppression(
     if return_indices:
         return [box_indices, num_valid_boxes]
 
-    return out
+    # TODO: do concat
+    return out_bboxes
+    # if id_index >= 0:
+    #     return concatenate([out_bboxes
+
+    # return out
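
Note on the `# TODO: do concat` left at the end of the patch: with `return_indices=False` the function now returns only `out_bboxes`, whereas the old code returned the full `(batch_size, num_anchors, box_data_length)` tensor with class id, score and coordinates packed together. Below is a minimal sketch of the missing packing step, assuming the usual layout with the class id in slot 0, the score in slot 1 and the four coordinates in slots 2..5 (`id_index=0`, `score_index=1`, `coord_start=2`), and assuming `data.dtype` is `float32` so the tensors can be concatenated directly. The helper name `_pack_nms_output` is hypothetical, not part of the patch.

```python
from tvm import topi


def _pack_nms_output(out_bboxes, out_scores, out_class_ids, id_index):
    """Hypothetical helper: pack the per-field NMS outputs back into one tensor.

    Assumes id_index=0, score_index=1, coord_start=2 and float32 data, so the
    result matches the old (batch_size, num_anchors, box_data_length) layout.
    """
    scores = topi.expand_dims(out_scores, axis=2)  # (batch, num_anchors, 1)
    if id_index >= 0:
        ids = topi.expand_dims(out_class_ids, axis=2)  # (batch, num_anchors, 1)
        # (batch, num_anchors, 6): class id, score, then the 4 box coordinates
        return topi.concatenate([ids, scores, out_bboxes], axis=2)
    # (batch, num_anchors, 5): score, then the 4 box coordinates
    return topi.concatenate([scores, out_bboxes], axis=2)
```

With such a helper, the `return_indices=False` path could end with `return _pack_nms_output(out_bboxes, out_scores, out_sorted_ids, id_index)` instead of the placeholder `return out_bboxes`; a general fix would also have to honour arbitrary `id_index`/`score_index`/`coord_start` values.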
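A second contextual note: the old inline Thrust check (current target is `cuda` and the `tvm.contrib.thrust.sort_nms` packed function is registered) is replaced by `is_thrust_available()` imported from `.sort`. That helper is not shown in this diff; presumably it reduces to a registry probe along the lines of the sketch below. The exact packed-function name it queries is an assumption, not confirmed by the patch.

```python
import tvm


def is_thrust_available():
    # Sketch of what the imported helper presumably does: report whether TVM was
    # built with the Thrust contrib kernels registered. The probed name
    # ("tvm.contrib.thrust.sort") is an assumption; the real helper lives in .sort.
    return tvm.get_global_func("tvm.contrib.thrust.sort", allow_missing=True) is not None
```

Unlike the removed code, the call site itself no longer checks `target.kind.name == "cuda"`; whether that check moved into the helper is not visible from this patch.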