Skip to content

Commit

Permalink
Add a highres option to support both the 300 and 512 versions.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed May 4, 2021
1 parent 36163dc commit 2c0f46d
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions torchvision/models/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i


class SSDFeatureExtractorResNet(nn.Module):
def __init__(self, backbone: resnet.ResNet):
def __init__(self, backbone: resnet.ResNet, highres: bool):
super().__init__()

self.features = nn.Sequential(
Expand Down Expand Up @@ -610,15 +610,16 @@ def __init__(self, backbone: resnet.ResNet):
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
),
nn.Sequential(
])
if highres:
extra.append(nn.Sequential(
nn.Conv2d(256, 128, kernel_size=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=2, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
),
])
))
_xavier_init(extra)
self.extra = extra

Expand All @@ -635,8 +636,8 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
return OrderedDict([(str(i), v) for i, v in enumerate(output)])


def _resnet_extractor(backbone_name: str, pretrained: bool, trainable_layers: int):
backbone = resnet.__dict__[backbone_name](pretrained=pretrained)
def _resnet_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int):
backbone = resnet.__dict__[backbone_name](pretrained=pretrained, progress=progress)

assert 0 <= trainable_layers <= 5
layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]
Expand All @@ -646,7 +647,7 @@ def _resnet_extractor(backbone_name: str, pretrained: bool, trainable_layers: in
if all([not name.startswith(layer) for layer in layers_to_train]):
parameter.requires_grad_(False)

return SSDFeatureExtractorResNet(backbone)
return SSDFeatureExtractorResNet(backbone, highres)


def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
Expand All @@ -657,7 +658,7 @@ def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes
if pretrained:
pretrained_backbone = False

backbone = _resnet_extractor("resnet50", pretrained_backbone, trainable_backbone_layers)
backbone = _resnet_extractor("resnet50", True, progress, pretrained_backbone, trainable_backbone_layers)
anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2], [2]],
scales=[0.04, 0.1, 0.26, 0.42, 0.58, 0.74, 0.9, 1.06],
steps=[8, 16, 32, 64, 128, 256, 512])
Expand Down

0 comments on commit 2c0f46d

Please sign in to comment.