From c5d28685ac001c71e1e43813883f2d2c8c645083 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 19 Nov 2020 16:10:04 +0000
Subject: [PATCH] Rewriting losses to remove branching.

---
 torchvision/models/detection/retinanet.py | 38 ++++++++++++--------------------------
 1 file changed, 12 insertions(+), 26 deletions(-)

diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py
index 6f587728e2c..d46d39543f8 100644
--- a/torchvision/models/detection/retinanet.py
+++ b/torchvision/models/detection/retinanet.py
@@ -107,20 +107,16 @@ def compute_loss(self, targets, head_outputs, matched_idxs):
             # determine only the foreground
             foreground_idxs_per_image = matched_idxs_per_image >= 0
             num_foreground = foreground_idxs_per_image.sum()
-            # no matched_idxs means there were no annotations in this image
-            if matched_idxs_per_image.numel() == 0:
-                gt_classes_target = torch.zeros_like(cls_logits_per_image)
-                valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0], device=cls_logits_per_image.device)
-            else:
-                # create the target classification
-                gt_classes_target = torch.zeros_like(cls_logits_per_image)
-                gt_classes_target[
-                    foreground_idxs_per_image,
-                    targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
-                ] = 1.0
-
-                # find indices for which anchors should be ignored
-                valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
+
+            # create the target classification
+            gt_classes_target = torch.zeros_like(cls_logits_per_image)
+            gt_classes_target[
+                foreground_idxs_per_image,
+                targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
+            ] = 1.0
+
+            # find indices for which anchors should be ignored
+            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
 
             # compute the classification loss
             losses.append(sigmoid_focal_loss(
@@ -190,22 +186,12 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in \
                 zip(targets, bbox_regression, anchors, matched_idxs):
-            # no matched_idxs means there were no annotations in this image
-            if matched_idxs_per_image.numel() == 0:
-                matched_gt_boxes_per_image = torch.zeros_like(bbox_regression_per_image)
-            else:
-                # get the targets corresponding GT for each proposal
-                # NB: need to clamp the indices because we can have a single
-                # GT in the image, and matched_idxs can be -2, which goes
-                # out of bounds
-                matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)]
-
             # determine only the foreground indices, ignore the rest
             foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
             num_foreground = foreground_idxs_per_image.numel()
 
             # select only the foreground boxes
-            matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
+            matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image[foreground_idxs_per_image]]
             bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
             anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
@@ -401,7 +387,7 @@ def compute_loss(self, targets, head_outputs, anchors):
         matched_idxs = []
         for anchors_per_image, targets_per_image in zip(anchors, targets):
             if targets_per_image['boxes'].numel() == 0:
-                matched_idxs.append(torch.empty((0,), dtype=torch.int64))
+                matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64))
                 continue
 
             match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image)
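Reviewer note: the branches can be dropped because the last hunk makes the matcher return a full-size tensor of -1s for images without annotations, so the unified classification path naturally produces an all-background target. Below is a minimal sketch of why that works, not the library's actual code; the tensor sizes are made up for illustration, and BETWEEN_THRESHOLDS = -2 follows torchvision's Matcher convention:

    import torch

    # Illustrative sizes (assumptions, not from the patch): 6 anchors, 3 classes.
    num_anchors, num_classes = 6, 3
    BETWEEN_THRESHOLDS = -2  # torchvision's Matcher marks ignored anchors with -2

    cls_logits_per_image = torch.randn(num_anchors, num_classes)

    # Image with no annotations: the matcher now emits a full-size tensor of -1
    # (every anchor is background) instead of an empty tensor.
    matched_idxs_per_image = torch.full((num_anchors,), -1, dtype=torch.int64)
    labels_per_image = torch.zeros((0,), dtype=torch.int64)  # no GT labels

    # The unified path from the patch, with no numel() == 0 branch:
    foreground_idxs_per_image = matched_idxs_per_image >= 0  # all False here
    gt_classes_target = torch.zeros_like(cls_logits_per_image)
    gt_classes_target[
        foreground_idxs_per_image,
        labels_per_image[matched_idxs_per_image[foreground_idxs_per_image]],
    ] = 1.0  # indexes zero rows, so the target stays all zeros
    valid_idxs_per_image = matched_idxs_per_image != BETWEEN_THRESHOLDS

    assert gt_classes_target.sum() == 0  # every anchor is a background target
    assert valid_idxs_per_image.all()    # and none are ignored

The regression head benefits the same way: matched_idxs_per_image[foreground_idxs_per_image] is empty for such an image, so indexing targets_per_image['boxes'] with it selects zero boxes and the clamp(min=0) workaround becomes unnecessary.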