diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index af99447514d03..bc2b94840f470 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -857,6 +857,13 @@ def _impl(inputs, attr, params, mod): nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32")) nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32")) + # Fill in invalid entries with 0 + box_range = _op.arange(_expr.const(0, dtype="int32"), max_total_size, dtype="int32") + box_range = _op.broadcast_to(_op.cast(box_range, "int64"), _op.shape_of(nmsed_scores)) + valid_mask = _op.cast(_op.less(box_range, num_detections), "float32") + nmsed_scores = nmsed_scores * valid_mask + # nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2) + return _expr.TupleWrapper( _expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, num_detections]), 4 ) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 971464ce26f2f..cc21d5ebc0669 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -1021,17 +1021,24 @@ def _collect_selected_indices_tf_ir( ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) ib.scope_attr(by, "thread_extent", nthread_by) + zero = cast(0, "int64") with ib.new_scope(): idx = bx * nthread_tx + tx idy = cast(by, "int64") batch_id = idy // num_class class_id = idy % num_class + offset = row_offsets[batch_id, class_id] + idx + with ib.if_scope(idx < num_detections[batch_id, class_id]): - offset = row_offsets[batch_id, class_id] - collected_indices[batch_id, offset + idx, 0] = class_id - collected_indices[batch_id, offset + idx, 1] = cast(selected_indices[idy, idx], "int64") - collected_scores[batch_id, offset + idx] = selected_scores[idy, idx] + collected_indices[batch_id, offset, 0] = class_id + collected_indices[batch_id, offset, 1] = cast(selected_indices[idy, idx], "int64") + collected_scores[batch_id, offset] = selected_scores[idy, idx] + with ib.else_scope(): + with ib.if_scope(idx < num_boxes): + collected_indices[batch_id, offset, 0] = zero + collected_indices[batch_id, offset, 1] = zero + collected_scores[batch_id, offset] = -1.0 return ib.get()