[fix]: fix softnms (#1019)
* fix basemodule

* fix typo

* fix unit test
jshilong authored May 13, 2021
1 parent da07114 commit 9b8dd08
Showing 2 changed files with 25 additions and 3 deletions.
9 changes: 6 additions & 3 deletions mmcv/ops/nms.py
@@ -282,18 +282,21 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
         # This assumes `dets` has 5 dimensions where
         # the last dimension is score.
         # TODO: more elegant way to handle the dimension issue.
+        # Some type of nms would reweight the score, such as SoftNMS
         scores = dets[:, 4]
     else:
         total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
+        # Some type of nms would reweight the score, such as SoftNMS
+        scores_after_nms = scores.new_zeros(scores.size())
         for id in torch.unique(idxs):
             mask = (idxs == id).nonzero(as_tuple=False).view(-1)
             dets, keep = nms_op(boxes_for_nms[mask], scores[mask], **nms_cfg_)
             total_mask[mask[keep]] = True
-
+            scores_after_nms[mask[keep]] = dets[:, -1]
         keep = total_mask.nonzero(as_tuple=False).view(-1)
-        keep = keep[scores[keep].argsort(descending=True)]
+        scores, inds = scores_after_nms[keep].sort(descending=True)
+        keep = keep[inds]
         boxes = boxes[keep]
-        scores = scores[keep]
 
     return torch.cat([boxes, scores[:, None]], -1), keep

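Why the change: `soft_nms` reweights detection scores, so when `batched_nms` takes the per-class branch (used once the number of boxes reaches `split_thr`), it must collect the scores returned by the NMS op (`dets[:, -1]`) and sort by those, rather than re-sorting by the original input scores; otherwise the split and non-split paths return different orderings and scores. Below is a minimal usage sketch with made-up boxes and labels (only `batched_nms`, the `soft_nms` config keys, and the 5-column `dets` layout come from the code above):

import torch
from mmcv.ops import batched_nms

# Hypothetical toy inputs: four boxes, two classes.
boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [2., 2., 12., 12.],
                      [20., 20., 30., 30.]])
scores = torch.tensor([0.9, 0.8, 0.7, 0.6])
idxs = torch.tensor([0, 0, 0, 1])  # class index of each box

nms_cfg = dict(type='soft_nms', iou_threshold=0.7)
dets, keep = batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False)

# dets has shape (num_kept, 5): [x1, y1, x2, y2, score]. With soft_nms the
# score column holds the reweighted scores, so it may differ from scores[keep];
# after this fix the per-class branch reports the same reweighted scores and
# ordering as the single-call fast path.
print(dets[:, -1], scores[keep])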
19 changes: 19 additions & 0 deletions tests/test_ops/test_nms.py
@@ -157,3 +157,22 @@ def test_batched_nms(self):
         assert torch.equal(keep, seq_keep)
         assert torch.equal(boxes, seq_boxes)
         assert torch.equal(keep, torch.from_numpy(results['keep']))
+
+        nms_cfg = dict(type='soft_nms', iou_threshold=0.7)
+        boxes, keep = batched_nms(
+            torch.from_numpy(results['boxes']),
+            torch.from_numpy(results['scores']),
+            torch.from_numpy(results['idxs']),
+            nms_cfg,
+            class_agnostic=False)
+
+        nms_cfg.update(split_thr=100)
+        seq_boxes, seq_keep = batched_nms(
+            torch.from_numpy(results['boxes']),
+            torch.from_numpy(results['scores']),
+            torch.from_numpy(results['idxs']),
+            nms_cfg,
+            class_agnostic=False)
+
+        assert torch.equal(keep, seq_keep)
+        assert torch.equal(boxes, seq_boxes)
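The added test exercises both code paths with the same `soft_nms` config: the first call uses the default `split_thr` (10000), so the whole batch goes through one NMS call, while `nms_cfg.update(split_thr=100)` presumably pushes the fixture's box count past the threshold and forces the per-class loop; the assertions then require both paths to return identical results. Since the pickled fixture behind `results` is not shown here, the following is only a rough sketch of the same check on synthetic data (shapes, seed, and class count are invented):

import torch
from mmcv.ops import batched_nms

# Synthetic stand-in for the pickled fixture: 200 random boxes, 5 classes.
torch.manual_seed(0)
boxes = torch.rand(200, 4) * 50
boxes[:, 2:] += boxes[:, :2] + 1      # ensure x2 > x1 and y2 > y1
scores = torch.rand(200)
idxs = torch.randint(0, 5, (200,))

cfg = dict(type='soft_nms', iou_threshold=0.7)
dets_fast, keep_fast = batched_nms(boxes, scores, idxs, cfg, class_agnostic=False)

cfg.update(split_thr=100)  # 200 boxes >= 100, so the per-class branch runs
dets_split, keep_split = batched_nms(boxes, scores, idxs, cfg, class_agnostic=False)

# With the fix, both branches should agree (barring exact score ties).
assert torch.equal(keep_fast, keep_split)
assert torch.equal(dets_fast, dets_split)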
