Skip to content

Commit

Permalink
fix converting score_threshold to Expr
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 8, 2021
1 parent 05fa415 commit f71e619
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
41 changes: 39 additions & 2 deletions python/tvm/relay/op/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,50 @@ def non_max_suppression(
def all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0
):
"""
TODO
"""Non-maximum suppression operator for object detection, corresponding to ONNX
NonMaxSuppression and TensorFlow combined_non_max_suppression.
NMS is performed for each class separately.
Parameters
----------
boxes : relay.Expr
3-D tensor with shape [batch_size, num_anchors, 6]
or [batch_size, num_anchors, 5].
The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom]
or [score, box_left, box_top, box_right, box_bottom]. It could
be the second output out_tensor of get_valid_counts.
scores: relay.Expr
2-D tensor with shape [batch_size, num_anchors], represents
the index of box in original data. It could be the third
output out_indices of get_valid_counts. The values in the
second dimension are like the output of arange(num_anchors)
if get_valid_counts is not used before non_max_suppression.
max_output_boxes_per_class : int or relay.Expr, optional
Max number of output valid boxes for each instance.
Return all valid boxes if the value of max_output_size is less than 0.
iou_threshold : float or relay.Expr, optionaIl
IoU test threshold
score_threshold : float or relay.Expr, optional
Score threshold to filter out low score boxes early
Returns
-------
out : relay.Tuple
The output is a relay.Tuple of two 2-D tensors, with
shape [batch_size, num_anchors] and [batch_size, num_valid_anchors] respectively.
"""
if not isinstance(max_output_boxes_per_class, expr.Expr):
max_output_boxes_per_class = expr.const(max_output_boxes_per_class, "int32")
if not isinstance(iou_threshold, expr.Expr):
iou_threshold = expr.const(iou_threshold, "float32")
if not isinstance(score_threshold, expr.Expr):
score_threshold = expr.const(score_threshold, "float32")

out = _make.all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold
)
Expand Down
4 changes: 1 addition & 3 deletions tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,9 +1371,7 @@ def verify_all_class_non_max_suppression(
score_threshold,
expected_indices,
):
dshape = boxes_np.shape

boxes = relay.var("boxes", relay.ty.TensorType(dshape, "float32"))
boxes = relay.var("boxes", relay.ty.TensorType(boxes_np.shape, "float32"))
scores = relay.var("scores", relay.ty.TensorType(scores_np.shape, "float32"))

out = relay.vision.all_class_non_max_suppression(
Expand Down

0 comments on commit f71e619

Please sign in to comment.