Add experimental resnet50 backbone.
datumbox committed May 3, 2021
1 parent 730c5e1 commit 2f1f578
Showing 3 changed files with 115 additions and 2 deletions.
Binary file not shown.
1 change: 1 addition & 0 deletions test/test_models.py
@@ -45,6 +45,7 @@ def get_available_video_models():
"keypointrcnn_resnet50_fpn": lambda x: x[1],
"retinanet_resnet50_fpn": lambda x: x[1],
"ssd300_vgg16": lambda x: x[1],
"ssd512_resnet50": lambda x: x[1],
}


116 changes: 114 additions & 2 deletions torchvision/models/detection/ssd.py
@@ -10,14 +10,15 @@
from .anchor_utils import DefaultBoxGenerator
from .backbone_utils import _validate_trainable_layers
from .transform import GeneralizedRCNNTransform
from .. import vgg
from .. import vgg, resnet
from ..utils import load_state_dict_from_url
from ...ops import boxes as box_ops

__all__ = ['SSD', 'ssd300_vgg16']
__all__ = ['SSD', 'ssd300_vgg16', 'ssd512_resnet50']

model_urls = {
'ssd300_vgg16_coco': 'https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth',
'ssd512_resnet50_coco': None, # TODO: add weights
}

backbone_urls = {
@@ -562,3 +563,114 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
model.load_state_dict(state_dict)
return model


class SSDFeatureExtractorResNet(nn.Module):
def __init__(self, backbone: resnet.ResNet):
super().__init__()

self.features = nn.Sequential(
backbone.conv1,
backbone.bn1,
backbone.relu,
backbone.maxpool,
backbone.layer1,
backbone.layer2,
backbone.layer3,
backbone.layer4,
)

# Patch last block's strides to get valid output sizes
for m in self.features[-1][0].modules():
if hasattr(m, 'stride'):
m.stride = 1

backbone_out_channels = self.features[-1][-1].bn3.num_features
extra = nn.ModuleList([
nn.Sequential(
nn.Conv2d(backbone_out_channels, 256, kernel_size=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
),
nn.Sequential(
nn.Conv2d(512, 256, kernel_size=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
),
nn.Sequential(
nn.Conv2d(512, 128, kernel_size=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
),
nn.Sequential(
nn.Conv2d(256, 128, kernel_size=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=3, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
),
nn.Sequential(
nn.Conv2d(256, 128, kernel_size=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=2, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
)
])
_xavier_init(extra)
self.extra = extra

def forward(self, x: Tensor) -> Dict[str, Tensor]:
x = self.features(x)
output = [x]

for block in self.extra:
x = block(x)
output.append(x)

return OrderedDict([(str(i), v) for i, v in enumerate(output)])


def _resnet_extractor(backbone_name: str, pretrained: bool, trainable_layers: int):
backbone = resnet.__dict__[backbone_name](pretrained=pretrained)

assert 0 <= trainable_layers <= 5
layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]
if trainable_layers == 5:
layers_to_train.append('bn1')
for name, parameter in backbone.named_parameters():
if all([not name.startswith(layer) for layer in layers_to_train]):
parameter.requires_grad_(False)

return SSDFeatureExtractorResNet(backbone)


def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
pretrained_backbone: bool = True, trainable_backbone_layers: Optional[int] = None, **kwargs: Any):
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5)

if pretrained:
pretrained_backbone = False

backbone = _resnet_extractor("resnet50", pretrained_backbone, trainable_backbone_layers)
anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]])
model = SSD(backbone, anchor_generator, (512, 512), num_classes, **kwargs)
if pretrained:
weights_name = 'ssd512_resnet50_coco'
if model_urls.get(weights_name, None) is None:
raise ValueError("No checkpoint is available for model {}".format(weights_name))
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
model.load_state_dict(state_dict)
return model
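
For reference, below is a minimal usage sketch of the new constructor. It assumes a torchvision build that contains this commit and that ssd512_resnet50 is re-exported under torchvision.models.detection like the other SSD variants; because model_urls['ssd512_resnet50_coco'] is still None, pretrained must be left at False.

import torch
from torchvision.models.detection import ssd512_resnet50  # assumes the usual re-export

# Build the experimental SSD512 model. No COCO checkpoint is published yet,
# so only the ImageNet ResNet-50 backbone weights can be loaded.
model = ssd512_resnet50(pretrained=False, pretrained_backbone=True)
model.eval()

# Detection models take a list of 3xHxW tensors and, in eval mode, return one
# dict per image with 'boxes', 'labels' and 'scores'.
images = [torch.rand(3, 512, 512)]
with torch.no_grad():
    detections = model(images)
print(detections[0]['boxes'].shape, detections[0]['scores'].shape)

Note that the feature extractor emits six feature maps (the stride-patched layer4 output plus the five extra blocks), matching the six aspect-ratio groups passed to DefaultBoxGenerator in ssd512_resnet50.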
