Skip to content

Commit

Permalink
remove score filtering in rtmdet_head rewriter since it leads to erro…
Browse files Browse the repository at this point in the history
…r shape in batch inference (#1762)
  • Loading branch information
lvhan028 authored Feb 15, 2023
1 parent 8cd3048 commit 73fd14d
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions mmdeploy/codebase/mmdet/models/dense_heads/rtmdet_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,7 @@ def __mark_pred_maps(cls_scores, bbox_preds):
br_x = (priors[..., 0] + flatten_bbox_preds[..., 2])
br_y = (priors[..., 1] + flatten_bbox_preds[..., 3])
bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1)
# directly multiply score factor and feed to nms
max_scores, _ = torch.max(flatten_cls_scores, 1)
mask = max_scores >= cfg.score_thr
scores = flatten_cls_scores.where(mask, flatten_cls_scores.new_zeros(1))
scores = flatten_cls_scores
if not with_nms:
return bboxes, scores

Expand Down

0 comments on commit 73fd14d

Please sign in to comment.