Skip to content

Commit

Permalink
minor change
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 29, 2021
1 parent f88e2a3 commit fc3a38e
Showing 1 changed file with 24 additions and 18 deletions.
42 changes: 24 additions & 18 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,24 +832,30 @@ def all_class_impl(
_op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)), "float32"
)

# TODO: support dynamic num_boxes
if max_output_boxes_per_batch < max_total_size:
arange = _op.arange(
_op.const(0, dtype="int64"),
_op.const(max_output_boxes_per_batch, dtype="int64"),
dtype="int64",
)
pad = _op.full(
_op.const(0, dtype="int64"), (max_total_size - max_output_boxes_per_batch,)
)
topk_indices = _op.tile(_op.concatenate([arange, pad], 0), tile_batch_reps)
nmsed_scores = _op.gather(selected_scores, 1, topk_indices)
nmsed_scores = nmsed_scores * valid_mask
else:
nmsed_scores, topk_indices = _op.topk(
selected_scores, k=max_total_size, axis=1, ret_type="both"
)
def select_topk(do_zero_pad):
def true_branch():
arange = _op.arange(
_op.const(0, dtype="int64"),
_op.const(max_output_boxes_per_batch, dtype="int64"),
dtype="int64",
)
pad = _op.full(
_op.const(0, dtype="int64"), (max_total_size - max_output_boxes_per_batch,)
)
topk_indices = _op.tile(_op.concatenate([arange, pad], 0), tile_batch_reps)
nmsed_scores = _op.gather(selected_scores, 1, topk_indices)
nmsed_scores = nmsed_scores * valid_mask
return nmsed_scores, topk_indices

def false_branch():
return _op.topk(selected_scores, k=max_total_size, axis=1, ret_type="both")

# TODO(masahi): support dynamic num_boxes
# return _expr.If(do_zero_pad, true_branch(), false_branch())
return true_branch() if do_zero_pad else false_branch()

assert isinstance(max_output_boxes_per_batch, int),"dynamic number of boxes not supported yet."
nmsed_scores, topk_indices = select_topk(max_output_boxes_per_batch < max_total_size)
topk_indices = _op.expand_dims(topk_indices, axis=0)
indices = _op.gather_nd(selected_indices, topk_indices, batch_dims=1)

Expand Down Expand Up @@ -894,7 +900,7 @@ def _impl(inputs, attr, params, mod):
q = boxes_shape[2]
num_classes = scores_shape[2]

if q == 1:
if q == 1 and isinstance(num_anchors, int):
boxes = _op.squeeze(boxes, axis=[2])
scores_trans = _op.transpose(scores, [0, 2, 1])
max_output_boxes_per_batch = num_anchors * num_classes
Expand Down

0 comments on commit fc3a38e

Please sign in to comment.