diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index f44fe85bca0d..1f229998b2a0 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -372,8 +372,14 @@ def nms_ir( dimension are like the output of arange(num_anchors) if get_valid_counts is not used before non_max_suppression. - out : Buffer - Output buffer, to be filled with sorted boxes. + out_bboxes : Buffer + Output buffer, to be filled with sorted box coordinates. + + out_scores : Buffer + Output buffer, to be filled with sorted scores. + + out_class_ids : Buffer + Output buffer, to be filled with sorted class ids. box_indices : Buffer A indices tensor mapping sorted indices to original indices @@ -617,7 +623,8 @@ def nms_inner_loop(ib, j): # Proceed to the inner loop if the box j is still valid with ib.if_scope(out_scores[i, j] > -1.0): with ib.if_scope(max_output_size > 0): - # No need to do more iteration if we have already reached max_output_size boxes + # No need to do more iteration if we have already reached max_output_size + # boxes # TODO(masahi): Add TIR while loop to realize early exit from the outer loop with ib.if_scope(num_valid_boxes_local[0] < max_output_size): nms_inner_loop(ib, j) @@ -666,6 +673,7 @@ def _fetch_score_ir(data, score, axis): def _get_sorted_indices(data, data_buf, score_index, score_shape): + """Extract a 1D score tensor from the packed input and do argsort on it.""" score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8) score_tensor = te.extern( [score_shape], @@ -707,6 +715,7 @@ def _run_nms( score_index, return_indices, ): + """Run NMS using sorted scores.""" sort_tensor_buf = tvm.tir.decl_buffer( sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8 ) @@ -757,17 +766,18 @@ def _run_nms( def _concatenate_outputs( - out_bboxes, out_scores, out_sorted_ids, out_shape, coord_start, score_index, id_index + out_bboxes, out_scores, out_class_ids, out_shape, coord_start, score_index, id_index ): + """Pack the results from NMS into a single 5D or 6D tensor.""" batch_size = out_bboxes.shape[0] num_anchors = out_bboxes.shape[1] - def ir(out_bboxes, out_scores, out_sorted_ids, out): + def ir(out_bboxes, out_scores, out_class_ids, out): ib = tvm.tir.ir_builder.create() out_bboxes = ib.buffer_ptr(out_bboxes) out_scores = ib.buffer_ptr(out_scores) - out_sorted_ids = ib.buffer_ptr(out_sorted_ids) + out_class_ids = ib.buffer_ptr(out_class_ids) out = ib.buffer_ptr(out) with ib.if_scope(num_anchors > 0): @@ -789,19 +799,14 @@ def ir(out_bboxes, out_scores, out_sorted_ids, out): out[i, tid, coord_start + j] = out_bboxes[i, tid, j] out[i, tid, score_index] = out_scores[i, tid] if id_index >= 0: - out[i, tid, score_index] = out_scores[i, tid] + out[i, tid, id_index] = out_class_ids[i, tid] return ib.get() return te.extern( [out_shape], - [out_bboxes, out_scores, out_sorted_ids], - lambda ins, outs: ir( - ins[0], - ins[1], - ins[2], - outs[0], # sorted bbox - ), + [out_bboxes, out_scores, out_class_ids], + lambda ins, outs: ir(ins[0], ins[1], ins[2], outs[0]), dtype=["float32"], name="nms_output_concat", tag="nms_output_concat", @@ -903,7 +908,7 @@ def non_max_suppression( sort_tensor = _get_sorted_indices(data, data_buf, score_index, (data.shape[0], data.shape[1])) - out_bboxes, out_scores, out_sorted_ids, box_indices, num_valid_boxes = _run_nms( + out_bboxes, out_scores, out_class_ids, box_indices, num_valid_boxes = _run_nms( data, data_buf, sort_tensor, @@ -923,5 +928,5 @@ def non_max_suppression( return [box_indices, num_valid_boxes] return _concatenate_outputs( - out_bboxes, out_scores, out_sorted_ids, data.shape, coord_start, score_index, id_index + out_bboxes, out_scores, out_class_ids, data.shape, coord_start, score_index, id_index )