diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 5ab1f5e91764..ae00a9a77139 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -832,24 +832,30 @@ def all_class_impl( _op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)), "float32" ) - # TODO: support dynamic num_boxes - if max_output_boxes_per_batch < max_total_size: - arange = _op.arange( - _op.const(0, dtype="int64"), - _op.const(max_output_boxes_per_batch, dtype="int64"), - dtype="int64", - ) - pad = _op.full( - _op.const(0, dtype="int64"), (max_total_size - max_output_boxes_per_batch,) - ) - topk_indices = _op.tile(_op.concatenate([arange, pad], 0), tile_batch_reps) - nmsed_scores = _op.gather(selected_scores, 1, topk_indices) - nmsed_scores = nmsed_scores * valid_mask - else: - nmsed_scores, topk_indices = _op.topk( - selected_scores, k=max_total_size, axis=1, ret_type="both" - ) + def select_topk(do_zero_pad): + def true_branch(): + arange = _op.arange( + _op.const(0, dtype="int64"), + _op.const(max_output_boxes_per_batch, dtype="int64"), + dtype="int64", + ) + pad = _op.full( + _op.const(0, dtype="int64"), (max_total_size - max_output_boxes_per_batch,) + ) + topk_indices = _op.tile(_op.concatenate([arange, pad], 0), tile_batch_reps) + nmsed_scores = _op.gather(selected_scores, 1, topk_indices) + nmsed_scores = nmsed_scores * valid_mask + return nmsed_scores, topk_indices + + def false_branch(): + return _op.topk(selected_scores, k=max_total_size, axis=1, ret_type="both") + + # TODO(masahi): support dynamic num_boxes + # return _expr.If(do_zero_pad, true_branch(), false_branch()) + return true_branch() if do_zero_pad else false_branch() + assert isinstance(max_output_boxes_per_batch, int), "dynamic number of boxes not supported yet." 
+ nmsed_scores, topk_indices = select_topk(max_output_boxes_per_batch < max_total_size) topk_indices = _op.expand_dims(topk_indices, axis=0) indices = _op.gather_nd(selected_indices, topk_indices, batch_dims=1) @@ -894,7 +900,7 @@ def _impl(inputs, attr, params, mod): q = boxes_shape[2] num_classes = scores_shape[2] - if q == 1: + if q == 1 and isinstance(num_anchors, int): boxes = _op.squeeze(boxes, axis=[2]) scores_trans = _op.transpose(scores, [0, 2, 1]) max_output_boxes_per_batch = num_anchors * num_classes