Skip to content

Commit

Permalink
Add experimental VGG-style resnet50 backbone.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed May 3, 2021
1 parent b640680 commit 36163dc
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions torchvision/models/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,12 +587,34 @@ def __init__(self, backbone: resnet.ResNet):
backbone_out_channels = self.features[-1][-1].bn3.num_features
extra = nn.ModuleList([
nn.Sequential(
nn.Conv2d(backbone_out_channels, 256, kernel_size=3, padding=1, stride=2, bias=False),
nn.Conv2d(backbone_out_channels, 256, kernel_size=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
),
nn.Sequential(
nn.Conv2d(512, 128, kernel_size=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
),
nn.Sequential(
nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2, bias=False),
nn.Conv2d(256, 128, kernel_size=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=3, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
),
nn.Sequential(
nn.Conv2d(256, 128, kernel_size=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=2, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
),
Expand Down Expand Up @@ -636,7 +658,9 @@ def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes
pretrained_backbone = False

backbone = _resnet_extractor("resnet50", pretrained_backbone, trainable_backbone_layers)
anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2]], min_ratio=0.04)
anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2], [2]],
scales=[0.04, 0.1, 0.26, 0.42, 0.58, 0.74, 0.9, 1.06],
steps=[8, 16, 32, 64, 128, 256, 512])
model = SSD(backbone, anchor_generator, (512, 512), num_classes, **kwargs)
if pretrained:
weights_name = 'ssd512_resnet50_coco'
Expand Down

0 comments on commit 36163dc

Please sign in to comment.