Skip to content

Commit

Permalink
working on zero padding
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 29, 2021
1 parent 4567417 commit b020064
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
7 changes: 7 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,13 @@ def _impl(inputs, attr, params, mod):
nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32"))
nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32"))

# Fill in invalid entries with 0
box_range = _op.arange(_expr.const(0, dtype="int32"), max_total_size, dtype="int32")
box_range = _op.broadcast_to(_op.cast(box_range, "int64"), _op.shape_of(nmsed_scores))
valid_mask = _op.cast(_op.less(box_range, num_detections), "float32")
nmsed_scores = nmsed_scores * valid_mask
# nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2)

return _expr.TupleWrapper(
_expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, num_detections]), 4
)
Expand Down
15 changes: 11 additions & 4 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,17 +1021,24 @@ def _collect_selected_indices_tf_ir(
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)
zero = cast(0, "int64")

with ib.new_scope():
idx = bx * nthread_tx + tx
idy = cast(by, "int64")
batch_id = idy // num_class
class_id = idy % num_class
offset = row_offsets[batch_id, class_id] + idx

with ib.if_scope(idx < num_detections[batch_id, class_id]):
offset = row_offsets[batch_id, class_id]
collected_indices[batch_id, offset + idx, 0] = class_id
collected_indices[batch_id, offset + idx, 1] = cast(selected_indices[idy, idx], "int64")
collected_scores[batch_id, offset + idx] = selected_scores[idy, idx]
collected_indices[batch_id, offset, 0] = class_id
collected_indices[batch_id, offset, 1] = cast(selected_indices[idy, idx], "int64")
collected_scores[batch_id, offset] = selected_scores[idy, idx]
with ib.else_scope():
with ib.if_scope(idx < num_boxes):
collected_indices[batch_id, offset, 0] = zero
collected_indices[batch_id, offset, 1] = zero
collected_scores[batch_id, offset] = -1.0

return ib.get()

Expand Down

0 comments on commit b020064

Please sign in to comment.