From 388b1e96f8086f4a849b26a5040bcf7a0988c609 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Tue, 17 Nov 2020 12:22:21 +0000
Subject: [PATCH 1/4] Enable support for images without annotations

---
 .../test_models_detection_negative_samples.py |  9 ++++
 torchvision/models/detection/retinanet.py     | 49 ++++++++++---------
 2 files changed, 34 insertions(+), 24 deletions(-)

diff --git a/test/test_models_detection_negative_samples.py b/test/test_models_detection_negative_samples.py
index ed0cc515940..6d767971f72 100644
--- a/test/test_models_detection_negative_samples.py
+++ b/test/test_models_detection_negative_samples.py
@@ -128,6 +128,15 @@ def test_forward_negative_sample_krcnn(self):
         self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
         self.assertEqual(loss_dict["loss_keypoint"], torch.tensor(0.))
 
+    def test_forward_negative_sample_retinanet(self):
+        model = torchvision.models.detection.retinanet_resnet50_fpn(
+            num_classes=2, min_size=100, max_size=100)
+
+        images, targets = self._make_empty_sample()
+        loss_dict = model(images, targets)
+
+        self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.))
+
 
 if __name__ == '__main__':
     unittest.main()

diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py
index fc05106a807..ec90863df47 100644
--- a/torchvision/models/detection/retinanet.py
+++ b/torchvision/models/detection/retinanet.py
@@ -108,10 +108,9 @@ def compute_loss(self, targets, head_outputs, matched_idxs):
             foreground_idxs_per_image = matched_idxs_per_image >= 0
             num_foreground = foreground_idxs_per_image.sum()
             # no matched_idxs means there were no annotations in this image
-            # TODO: enable support for images without annotations that works on distributed
-            if False:  # matched_idxs_per_image.numel() == 0:
+            if matched_idxs_per_image.numel() == 0:
                 gt_classes_target = torch.zeros_like(cls_logits_per_image)
-                valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0])
+                valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0], device=cls_logits_per_image.device)
             else:
                 # create the target classification
                 gt_classes_target = torch.zeros_like(cls_logits_per_image)
@@ -192,27 +191,29 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in \
                 zip(targets, bbox_regression, anchors, matched_idxs):
             # no matched_idxs means there were no annotations in this image
-            # TODO enable support for images without annotations with distributed support
-            # if matched_idxs_per_image.numel() == 0:
-            #     continue
-
-            # get the targets corresponding GT for each proposal
-            # NB: need to clamp the indices because we can have a single
-            # GT in the image, and matched_idxs can be -2, which goes
-            # out of bounds
-            matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)]
-
-            # determine only the foreground indices, ignore the rest
-            foreground_idxs_per_image = matched_idxs_per_image >= 0
-            num_foreground = foreground_idxs_per_image.sum()
-
-            # select only the foreground boxes
-            matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
-            bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
-            anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
-
-            # compute the regression targets
-            target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+            if matched_idxs_per_image.numel() == 0:
+                device = targets_per_image['boxes'].device
+                bbox_regression_per_image = torch.zeros_like(targets_per_image['boxes'], device=device)
+                target_regression = torch.zeros_like(targets_per_image['boxes'], device=device)
+                num_foreground = torch.tensor(0, dtype=torch.int64, device=device)
+            else:
+                # get the targets corresponding GT for each proposal
+                # NB: need to clamp the indices because we can have a single
+                # GT in the image, and matched_idxs can be -2, which goes
+                # out of bounds
+                matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)]
+
+                # determine only the foreground indices, ignore the rest
+                foreground_idxs_per_image = matched_idxs_per_image >= 0
+                num_foreground = foreground_idxs_per_image.sum()
+
+                # select only the foreground boxes
+                matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
+                bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
+                anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
+
+                # compute the regression targets
+                target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
 
             # compute the loss
             losses.append(torch.nn.functional.l1_loss(
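The scenario the new test exercises is a training step on an image whose target contains zero boxes. A minimal standalone sketch of that scenario follows; the empty-target shapes are assumptions modeled on the suite's _make_empty_sample helper, which is not shown in this diff:

import torch
import torchvision

model = torchvision.models.detection.retinanet_resnet50_fpn(
    num_classes=2, min_size=100, max_size=100)
model.train()

# an image with no annotations: zero rows of boxes, zero labels
images = [torch.rand(3, 100, 100)]
targets = [{'boxes': torch.zeros((0, 4), dtype=torch.float32),
            'labels': torch.zeros((0,), dtype=torch.int64)}]

loss_dict = model(images, targets)
# with no foreground anchors, the regression loss collapses to zero
assert loss_dict['bbox_regression'] == torch.tensor(0.)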
From e8d5822c031f026f15a15c4686ccb5d61f8367c2 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 19 Nov 2020 15:21:24 +0000
Subject: [PATCH 2/4] Ensuring gradient propagates to RegressionHead.

---
 torchvision/models/detection/retinanet.py | 25 ++++++++++-------------
 1 file changed, 11 insertions(+), 14 deletions(-)

diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py
index ec90863df47..6f587728e2c 100644
--- a/torchvision/models/detection/retinanet.py
+++ b/torchvision/models/detection/retinanet.py
@@ -192,10 +192,7 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
                 zip(targets, bbox_regression, anchors, matched_idxs):
             # no matched_idxs means there were no annotations in this image
             if matched_idxs_per_image.numel() == 0:
-                device = targets_per_image['boxes'].device
-                bbox_regression_per_image = torch.zeros_like(targets_per_image['boxes'], device=device)
-                target_regression = torch.zeros_like(targets_per_image['boxes'], device=device)
-                num_foreground = torch.tensor(0, dtype=torch.int64, device=device)
+                matched_gt_boxes_per_image = torch.zeros_like(bbox_regression_per_image)
             else:
                 # get the targets corresponding GT for each proposal
                 # NB: need to clamp the indices because we can have a single
@@ -203,17 +200,17 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
                 # out of bounds
                 matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)]
 
-                # determine only the foreground indices, ignore the rest
-                foreground_idxs_per_image = matched_idxs_per_image >= 0
-                num_foreground = foreground_idxs_per_image.sum()
+            # determine only the foreground indices, ignore the rest
+            foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
+            num_foreground = foreground_idxs_per_image.numel()
 
-                # select only the foreground boxes
-                matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
-                bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
-                anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
+            # select only the foreground boxes
+            matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
+            bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
+            anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
 
-                # compute the regression targets
-                target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+            # compute the regression targets
+            target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
 
             # compute the loss
             losses.append(torch.nn.functional.l1_loss(
@@ -404,7 +401,7 @@ def compute_loss(self, targets, head_outputs, anchors):
         matched_idxs = []
         for anchors_per_image, targets_per_image in zip(anchors, targets):
             if targets_per_image['boxes'].numel() == 0:
-                matched_idxs.append(torch.empty((0,), dtype=torch.int32))
+                matched_idxs.append(torch.empty((0,), dtype=torch.int64))
                 continue
 
             match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image)
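The switch from fresh zero tensors to indexing into the head's own output matters for autograd: the empty-image loss must stay connected to the model's outputs so that a zero gradient, rather than no gradient at all, reaches the RegressionHead (which distributed training depends on); the int32-to-int64 change likewise keeps the empty matched_idxs usable as an index tensor. A toy sketch of the graph-connectivity distinction, with stand-in tensors rather than the actual head:

import torch

pred = torch.randn(5, 4, requires_grad=True)  # stands in for bbox_regression_per_image
keep = torch.empty(0, dtype=torch.int64)      # no foreground anchors matched

# indexing keeps the (empty) result attached to the autograd graph
loss = torch.nn.functional.l1_loss(pred[keep], torch.zeros((0, 4)), reduction='sum')
print(loss.requires_grad)              # True: backward() delivers zero grads to pred
print(torch.tensor(0.).requires_grad)  # False: a constant loss would leave the head unused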
From c5d28685ac001c71e1e43813883f2d2c8c645083 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 19 Nov 2020 16:10:04 +0000
Subject: [PATCH 3/4] Rewriting losses to remove branching.

---
 torchvision/models/detection/retinanet.py | 38 +++++++----------------
 1 file changed, 12 insertions(+), 26 deletions(-)

diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py
index 6f587728e2c..d46d39543f8 100644
--- a/torchvision/models/detection/retinanet.py
+++ b/torchvision/models/detection/retinanet.py
@@ -107,20 +107,16 @@ def compute_loss(self, targets, head_outputs, matched_idxs):
             # determine only the foreground
             foreground_idxs_per_image = matched_idxs_per_image >= 0
             num_foreground = foreground_idxs_per_image.sum()
-            # no matched_idxs means there were no annotations in this image
-            if matched_idxs_per_image.numel() == 0:
-                gt_classes_target = torch.zeros_like(cls_logits_per_image)
-                valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0], device=cls_logits_per_image.device)
-            else:
-                # create the target classification
-                gt_classes_target = torch.zeros_like(cls_logits_per_image)
-                gt_classes_target[
-                    foreground_idxs_per_image,
-                    targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
-                ] = 1.0
-
-                # find indices for which anchors should be ignored
-                valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
+
+            # create the target classification
+            gt_classes_target = torch.zeros_like(cls_logits_per_image)
+            gt_classes_target[
+                foreground_idxs_per_image,
+                targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
+            ] = 1.0
+
+            # find indices for which anchors should be ignored
+            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
 
             # compute the classification loss
             losses.append(sigmoid_focal_loss(
@@ -186,22 +186,12 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
 
         for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in \
                 zip(targets, bbox_regression, anchors, matched_idxs):
-            # no matched_idxs means there were no annotations in this image
-            if matched_idxs_per_image.numel() == 0:
-                matched_gt_boxes_per_image = torch.zeros_like(bbox_regression_per_image)
-            else:
-                # get the targets corresponding GT for each proposal
-                # NB: need to clamp the indices because we can have a single
-                # GT in the image, and matched_idxs can be -2, which goes
-                # out of bounds
-                matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)]
-
             # determine only the foreground indices, ignore the rest
             foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
             num_foreground = foreground_idxs_per_image.numel()
 
             # select only the foreground boxes
-            matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
+            matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image[foreground_idxs_per_image]]
             bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
             anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
 
             # compute the regression targets
             target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
@@ -401,7 +387,7 @@ def compute_loss(self, targets, head_outputs, anchors):
         matched_idxs = []
         for anchors_per_image, targets_per_image in zip(anchors, targets):
             if targets_per_image['boxes'].numel() == 0:
-                matched_idxs.append(torch.empty((0,), dtype=torch.int64))
+                matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64))
                 continue
 
             match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image)
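The branching disappears because an image without annotations is now encoded as a matched_idxs vector of all -1, sized to the anchor count: every anchor reads as background, so both heads handle the empty case through their ordinary code path. A small sketch of the invariant, using a toy anchor count; the -2 sentinel mirrors the self.BETWEEN_THRESHOLDS value referenced by the classification head:

import torch

num_anchors = 8
BETWEEN_THRESHOLDS = -2  # sentinel for anchors ignored by the classification loss

# image without annotations: all anchors marked background (-1)
matched_idxs = torch.full((num_anchors,), -1, dtype=torch.int64)

foreground = torch.where(matched_idxs >= 0)[0]
valid = matched_idxs != BETWEEN_THRESHOLDS

print(foreground.numel())  # 0 -> the regression loss sums over nothing, no special case
print(valid.all().item())  # True -> every anchor still counts as a negative in focal loss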
From 06ebee1a9f10c76d8ac5768fd578362dd5ace6e9 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 27 Nov 2020 17:18:29 +0000
Subject: [PATCH 4/4] Fix the seed on DeformConv autocast test.

---
 test/test_ops.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/test/test_ops.py b/test/test_ops.py
index 1ba40d0da5f..68e6a5d2825 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1,3 +1,4 @@
+from common_utils import set_rng_seed
 import math
 import unittest
 
@@ -655,6 +656,7 @@ def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
     def test_autocast(self):
+        set_rng_seed(0)
         for dtype in (torch.float, torch.half):
             with torch.cuda.amp.autocast():
                 self._test_forward(torch.device("cuda"), False, dtype=dtype)
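Seeding before the autocast run makes the float16/float32 comparison inside _test_forward deterministic across executions, so the test no longer flakes on unlucky random inputs. The common_utils helper itself is not part of this diff; presumably it resets the relevant generators roughly along these lines (an assumption, not the verified implementation):

import random
import torch

def set_rng_seed(seed):
    # reset every RNG the test inputs may draw from
    torch.manual_seed(seed)  # also seeds the CUDA generators when available
    random.seed(seed)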