Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/layer freezing maskrcnn keypointrcnn #2242

Merged
merged 6 commits into from
May 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion test/test_models_detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torchvision.models.detection import _utils
from torchvision.models.detection.transform import GeneralizedRCNNTransform
import unittest
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection import fasterrcnn_resnet50_fpn, maskrcnn_resnet50_fpn, keypointrcnn_resnet50_fpn


class Tester(unittest.TestCase):
Expand Down Expand Up @@ -35,6 +35,36 @@ def test_fasterrcnn_resnet50_fpn_frozen_layers(self):
# check that expected initial number of layers are frozen
self.assertTrue(all(is_frozen[:exp_froz_params]))

def test_maskrcnn_resnet50_fpn_frozen_layers(self):
    """Check how many leading Mask R-CNN parameters are frozen.

    For every supported ``trainable_backbone_layers`` value the model is
    built and the first N parameters (in ``model.parameters()`` order)
    must all have ``requires_grad == False``, where N is the expected
    count from the table below (e.g. 53 for 0 trainable layers, 24 for
    2 trainable layers).

    NOTE: only the leading prefix is checked; parameters after that
    prefix are not asserted to be trainable, and for an expected count
    of 0 the check is vacuously true.
    """
    # trainable_backbone_layers value -> expected number of leading frozen parameters
    frozen_count_by_trainable_layers = {0: 53, 1: 43, 2: 24, 3: 11, 4: 1, 5: 0}
    for trainable_layers, frozen_count in frozen_count_by_trainable_layers.items():
        model = maskrcnn_resnet50_fpn(
            pretrained=True,
            progress=False,
            num_classes=91,
            pretrained_backbone=False,
            trainable_backbone_layers=trainable_layers,
        )
        # the first `frozen_count` parameters are the ones expected to be frozen
        leading_params = list(model.parameters())[:frozen_count]
        self.assertTrue(all(not param.requires_grad for param in leading_params))

def test_keypointrcnn_resnet50_fpn_frozen_layers(self):
    """Check how many leading Keypoint R-CNN parameters are frozen.

    For every supported ``trainable_backbone_layers`` value the model is
    built and the first N parameters (in ``model.parameters()`` order)
    must all have ``requires_grad == False``, where N is the expected
    count from the table below (e.g. 53 for 0 trainable layers, 24 for
    2 trainable layers).

    NOTE: only the leading prefix is checked; parameters after that
    prefix are not asserted to be trainable, and for an expected count
    of 0 the check is vacuously true.
    """
    # trainable_backbone_layers value -> expected number of leading frozen parameters
    frozen_count_by_trainable_layers = {0: 53, 1: 43, 2: 24, 3: 11, 4: 1, 5: 0}
    for trainable_layers, frozen_count in frozen_count_by_trainable_layers.items():
        model = keypointrcnn_resnet50_fpn(
            pretrained=True,
            progress=False,
            num_classes=2,
            pretrained_backbone=False,
            trainable_backbone_layers=trainable_layers,
        )
        # the first `frozen_count` parameters are the ones expected to be frozen
        leading_params = list(model.parameters())[:frozen_count]
        self.assertTrue(all(not param.requires_grad for param in leading_params))

def test_transform_copy_targets(self):
transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3))
image = [torch.rand(3, 200, 300), torch.rand(3, 200, 200)]
Expand Down
12 changes: 10 additions & 2 deletions torchvision/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def forward(self, x):

def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
num_classes=2, num_keypoints=17,
pretrained_backbone=True, **kwargs):
pretrained_backbone=True, trainable_backbone_layers=3, **kwargs):
"""
Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.

Expand Down Expand Up @@ -314,11 +314,19 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
Arguments:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
num_classes (int): number of output classes of the model (including the background)
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0
# don't freeze any layers if neither a pretrained model nor a pretrained backbone is used
if not (pretrained or pretrained_backbone):
trainable_backbone_layers = 5
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers)
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
if pretrained:
key = 'keypointrcnn_resnet50_fpn_coco'
Expand Down
12 changes: 10 additions & 2 deletions torchvision/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def __init__(self, in_channels, dim_reduced, num_classes):


def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
num_classes=91, pretrained_backbone=True, **kwargs):
num_classes=91, pretrained_backbone=True, trainable_backbone_layers=3, **kwargs):
"""
Constructs a Mask R-CNN model with a ResNet-50-FPN backbone.

Expand Down Expand Up @@ -310,11 +310,19 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
Arguments:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
num_classes (int): number of output classes of the model (including the background)
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0
# don't freeze any layers if neither a pretrained model nor a pretrained backbone is used
if not (pretrained or pretrained_backbone):
trainable_backbone_layers = 5
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers)
model = MaskRCNN(backbone, num_classes, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'],
Expand Down