Rewriting losses to remove branching.
datumbox committed Nov 19, 2020
1 parent e8d5822 commit c5d2868
Showing 1 changed file with 12 additions and 26 deletions.
38 changes: 12 additions & 26 deletions torchvision/models/detection/retinanet.py
@@ -107,20 +107,16 @@ def compute_loss(self, targets, head_outputs, matched_idxs):
             # determine only the foreground
             foreground_idxs_per_image = matched_idxs_per_image >= 0
             num_foreground = foreground_idxs_per_image.sum()
-            # no matched_idxs means there were no annotations in this image
-            if matched_idxs_per_image.numel() == 0:
-                gt_classes_target = torch.zeros_like(cls_logits_per_image)
-                valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0], device=cls_logits_per_image.device)
-            else:
-                # create the target classification
-                gt_classes_target = torch.zeros_like(cls_logits_per_image)
-                gt_classes_target[
-                    foreground_idxs_per_image,
-                    targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
-                ] = 1.0
-
-                # find indices for which anchors should be ignored
-                valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
+
+            # create the target classification
+            gt_classes_target = torch.zeros_like(cls_logits_per_image)
+            gt_classes_target[
+                foreground_idxs_per_image,
+                targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
+            ] = 1.0
+
+            # find indices for which anchors should be ignored
+            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
 
             # compute the classification loss
             losses.append(sigmoid_focal_loss(
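Why the `numel() == 0` special case can go: the matcher change in the last hunk guarantees that matched_idxs_per_image always has one entry per anchor, with -1 for every anchor of an image without annotations. The boolean foreground mask is then all-False, the fancy-indexed assignment writes nothing, and the classification target correctly stays all zeros. A minimal sketch of that behaviour (shapes and values here are illustrative, not from the commit):

    import torch

    # Illustrative sizes: 5 anchors, 3 classes, an image with no annotations.
    num_anchors, num_classes = 5, 3
    cls_logits_per_image = torch.randn(num_anchors, num_classes)
    matched_idxs_per_image = torch.full((num_anchors,), -1, dtype=torch.int64)
    labels_per_image = torch.empty((0,), dtype=torch.int64)  # no ground-truth labels

    foreground_idxs_per_image = matched_idxs_per_image >= 0   # all False
    gt_classes_target = torch.zeros_like(cls_logits_per_image)
    # Both index tensors are empty, so this assignment is a no-op and the
    # target stays all-background (all zeros), matching the removed branch.
    gt_classes_target[
        foreground_idxs_per_image,
        labels_per_image[matched_idxs_per_image[foreground_idxs_per_image]]
    ] = 1.0
    assert gt_classes_target.sum() == 0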
@@ -190,22 +186,12 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):

         for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in \
                 zip(targets, bbox_regression, anchors, matched_idxs):
-            # no matched_idxs means there were no annotations in this image
-            if matched_idxs_per_image.numel() == 0:
-                matched_gt_boxes_per_image = torch.zeros_like(bbox_regression_per_image)
-            else:
-                # get the targets corresponding GT for each proposal
-                # NB: need to clamp the indices because we can have a single
-                # GT in the image, and matched_idxs can be -2, which goes
-                # out of bounds
-                matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)]
-
             # determine only the foreground indices, ignore the rest
             foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
             num_foreground = foreground_idxs_per_image.numel()
 
             # select only the foreground boxes
-            matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
+            matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image[foreground_idxs_per_image]]
             bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
             anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
@@ -401,7 +387,7 @@ def compute_loss(self, targets, head_outputs, anchors):
         matched_idxs = []
         for anchors_per_image, targets_per_image in zip(anchors, targets):
             if targets_per_image['boxes'].numel() == 0:
-                matched_idxs.append(torch.empty((0,), dtype=torch.int64))
+                matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64))
                 continue
 
             match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image)
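This matcher change is what makes the two simplifications above safe: an image without boxes now yields a full-length tensor of -1 (all background) instead of an empty tensor, so the per-anchor alignment assumed by the masks in both loss functions always holds. A sketch of the new contract, with an assumed anchor count:

    import torch

    num_anchors = 4  # assumed value for illustration
    matched_idxs_per_image = torch.full((num_anchors,), -1, dtype=torch.int64)

    # One entry per anchor, all marked background: downstream masks such as
    # `matched_idxs_per_image >= 0` stay aligned with the anchors, so the old
    # `numel() == 0` emptiness checks become unnecessary.
    assert matched_idxs_per_image.shape == (num_anchors,)
    assert (matched_idxs_per_image >= 0).sum() == 0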