diff --git a/mmdet/core/bbox/assigners/sim_ota_assigner.py b/mmdet/core/bbox/assigners/sim_ota_assigner.py index 9e25ea09d19..5a5902970af 100644 --- a/mmdet/core/bbox/assigners/sim_ota_assigner.py +++ b/mmdet/core/bbox/assigners/sim_ota_assigner.py @@ -127,7 +127,13 @@ def _assign(self, assigned_gt_inds = decoded_bboxes.new_full((num_bboxes, ), 0, dtype=torch.long) - if num_gt == 0 or num_bboxes == 0: + valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info( + priors, gt_bboxes) + valid_decoded_bbox = decoded_bboxes[valid_mask] + valid_pred_scores = pred_scores[valid_mask] + num_valid = valid_decoded_bbox.size(0) + + if num_gt == 0 or num_bboxes == 0 or num_valid == 0: # No ground truth or boxes, return empty assignment max_overlaps = decoded_bboxes.new_zeros((num_bboxes, )) if num_gt == 0: @@ -142,13 +148,6 @@ def _assign(self, return AssignResult( num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) - valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info( - priors, gt_bboxes) - - valid_decoded_bbox = decoded_bboxes[valid_mask] - valid_pred_scores = pred_scores[valid_mask] - num_valid = valid_decoded_bbox.size(0) - pairwise_ious = bbox_overlaps(valid_decoded_bbox, gt_bboxes) iou_cost = -torch.log(pairwise_ious + eps)