From f88e2a3a98a7ee283622e57712e28634374e5e2c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 30 May 2021 07:14:54 +0900 Subject: [PATCH] minor refactor --- python/tvm/relay/frontend/tensorflow.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index fa591d5e31c7..5ab1f5e91764 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -796,8 +796,7 @@ def _impl(inputs, attr, params, mod): def _combined_nms(): def all_class_impl( batch_size, - num_boxes, - num_classes, + max_output_boxes_per_batch, boxes, scores, max_output_boxes_per_class, @@ -805,7 +804,6 @@ def all_class_impl( score_threshold, max_total_size, clip_boxes, - mod, ): ( selected_indices, @@ -835,14 +833,15 @@ def all_class_impl( ) # TODO: support dynamic num_boxes - max_output_boxes = num_boxes * num_classes - if max_output_boxes < max_total_size: + if max_output_boxes_per_batch < max_total_size: arange = _op.arange( _op.const(0, dtype="int64"), - _op.const(max_output_boxes, dtype="int64"), + _op.const(max_output_boxes_per_batch, dtype="int64"), dtype="int64", ) - pad = _op.full(_op.const(0, dtype="int64"), (max_total_size - max_output_boxes,)) + pad = _op.full( + _op.const(0, dtype="int64"), (max_total_size - max_output_boxes_per_batch,) + ) topk_indices = _op.tile(_op.concatenate([arange, pad], 0), tile_batch_reps) nmsed_scores = _op.gather(selected_scores, 1, topk_indices) nmsed_scores = nmsed_scores * valid_mask @@ -898,10 +897,10 @@ def _impl(inputs, attr, params, mod): if q == 1: boxes = _op.squeeze(boxes, axis=[2]) scores_trans = _op.transpose(scores, [0, 2, 1]) + max_output_boxes_per_batch = num_anchors * num_classes return all_class_impl( batch_size, - num_anchors, - num_classes, + max_output_boxes_per_batch, boxes, scores_trans, max_output_size, @@ -909,7 +908,6 @@ def _impl(inputs, attr, params, mod): score_threshold, max_total_size.data.numpy().item(), attr["clip_boxes"], - mod, ) boxes = _op.reshape(boxes, newshape=[batch_size, num_anchors * num_classes, 4])