Skip to content

Commit

Permalink
Make OHEM work with seesaw loss (#6514)
Browse files Browse the repository at this point in the history
  • Loading branch information
ohwi authored Nov 19, 2021
1 parent a3258bc commit 25c8d75
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion mmdet/core/bbox/samplers/ohem_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self,
context,
neg_pos_ub=-1,
add_gt_as_proposals=True,
loss_key='loss_cls',
**kwargs):
super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
add_gt_as_proposals)
Expand All @@ -28,6 +29,8 @@ def __init__(self,
else:
self.bbox_head = self.context.bbox_head[self.context.current_stage]

self.loss_key = loss_key

def hard_mining(self, inds, num_expected, bboxes, labels, feats):
with torch.no_grad():
rois = bbox2roi([bboxes])
Expand All @@ -45,7 +48,7 @@ def hard_mining(self, inds, num_expected, bboxes, labels, feats):
label_weights=cls_score.new_ones(cls_score.size(0)),
bbox_targets=None,
bbox_weights=None,
reduction_override='none')['loss_cls']
reduction_override='none')[self.loss_key]
_, topk_loss_inds = loss.topk(num_expected)
return inds[topk_loss_inds]

Expand Down

0 comments on commit 25c8d75

Please sign in to comment.