diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py index a9cb5a98c17..9abef94ce9f 100644 --- a/test/test_models_detection_utils.py +++ b/test/test_models_detection_utils.py @@ -3,7 +3,7 @@ from torchvision.models.detection import _utils from torchvision.models.detection.transform import GeneralizedRCNNTransform import unittest -from torchvision.models.detection import fasterrcnn_resnet50_fpn +from torchvision.models.detection import fasterrcnn_resnet50_fpn, maskrcnn_resnet50_fpn, keypointrcnn_resnet50_fpn class Tester(unittest.TestCase): @@ -35,6 +35,36 @@ def test_fasterrcnn_resnet50_fpn_frozen_layers(self): # check that expected initial number of layers are frozen self.assertTrue(all(is_frozen[:exp_froz_params])) + def test_maskrcnn_resnet50_fpn_frozen_layers(self): + # we know how many initial layers and parameters of the maskrcnn should + # be frozen for each trainable_backbone_layers parameter value + # i.e. all 53 params are frozen if trainable_backbone_layers=0 + # and first 24 params are frozen if trainable_backbone_layers=2 + expected_frozen_params = {0: 53, 1: 43, 2: 24, 3: 11, 4: 1, 5: 0} + for train_layers, exp_froz_params in expected_frozen_params.items(): + model = maskrcnn_resnet50_fpn(pretrained=True, progress=False, + num_classes=91, pretrained_backbone=False, + trainable_backbone_layers=train_layers) + # boolean list that is true if the parameter at that index is frozen + is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()] + # check that expected initial number of layers in maskrcnn are frozen + self.assertTrue(all(is_frozen[:exp_froz_params])) + + def test_keypointrcnn_resnet50_fpn_frozen_layers(self): + # we know how many initial layers and parameters of the keypointrcnn should + # be frozen for each trainable_backbone_layers parameter value + # i.e. all 53 params are frozen if trainable_backbone_layers=0 + # and first 24 params are frozen if trainable_backbone_layers=2 +
expected_frozen_params = {0: 53, 1: 43, 2: 24, 3: 11, 4: 1, 5: 0} + for train_layers, exp_froz_params in expected_frozen_params.items(): + model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False, + num_classes=2, pretrained_backbone=False, + trainable_backbone_layers=train_layers) + # boolean list that is true if the parameter at that index is frozen + is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()] + # check that expected initial number of layers in keypointrcnn are frozen + self.assertTrue(all(is_frozen[:exp_froz_params])) + def test_transform_copy_targets(self): transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3)) image = [torch.rand(3, 200, 300), torch.rand(3, 200, 200)] diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 257932fff9a..fadf0cf60f0 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -270,7 +270,7 @@ def forward(self, x): def keypointrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=2, num_keypoints=17, - pretrained_backbone=True, **kwargs): + pretrained_backbone=True, trainable_backbone_layers=3, **kwargs): """ Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone. @@ -314,11 +314,19 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True, Arguments: pretrained (bool): If True, returns a model pre-trained on COCO train2017 progress (bool): If True, displays a progress bar of the download to stderr + pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet + num_classes (int): number of output classes of the model (including the background) + trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. 
""" + assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0 + # dont freeze any layers if pretrained model or backbone is not used + if not (pretrained or pretrained_backbone): + trainable_backbone_layers = 5 if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = resnet_fpn_backbone('resnet50', pretrained_backbone) + backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers) model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) if pretrained: key = 'keypointrcnn_resnet50_fpn_coco' diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index a8a980fa3ce..32c63ca4cf1 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -265,7 +265,7 @@ def __init__(self, in_channels, dim_reduced, num_classes): def maskrcnn_resnet50_fpn(pretrained=False, progress=True, - num_classes=91, pretrained_backbone=True, **kwargs): + num_classes=91, pretrained_backbone=True, trainable_backbone_layers=3, **kwargs): """ Constructs a Mask R-CNN model with a ResNet-50-FPN backbone. @@ -310,11 +310,19 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True, Arguments: pretrained (bool): If True, returns a model pre-trained on COCO train2017 progress (bool): If True, displays a progress bar of the download to stderr + pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet + num_classes (int): number of output classes of the model (including the background) + trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. 
""" + assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0 + # dont freeze any layers if pretrained model or backbone is not used + if not (pretrained or pretrained_backbone): + trainable_backbone_layers = 5 if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = resnet_fpn_backbone('resnet50', pretrained_backbone) + backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers) model = MaskRCNN(backbone, num_classes, **kwargs) if pretrained: state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'],