-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add SSD512 with ResNet50 backbone #3760
Changes from all commits
2f1f578
d46a302
0c17b0a
b640680
36163dc
2c0f46d
eef01bc
9cf7c5d
d419eea
e6fb426
40da375
e66b800
e5472d0
e526e32
18495f3
14299da
ea1e2c4
04ec56a
777126d
87d0153
8b2715d
d08fc10
61ae292
18bf381
644bdcd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,14 +10,15 @@ | |
from .anchor_utils import DefaultBoxGenerator | ||
from .backbone_utils import _validate_trainable_layers | ||
from .transform import GeneralizedRCNNTransform | ||
from .. import vgg | ||
from .. import vgg, resnet | ||
from ..utils import load_state_dict_from_url | ||
from ...ops import boxes as box_ops | ||
|
||
__all__ = ['SSD', 'ssd300_vgg16'] | ||
__all__ = ['SSD', 'ssd300_vgg16', 'ssd512_resnet50'] | ||
|
||
model_urls = { | ||
'ssd300_vgg16_coco': 'https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth', | ||
'ssd512_resnet50_coco': 'https://download.pytorch.org/models/ssd512_resnet50_coco-d6d7edbb.pth', | ||
} | ||
|
||
backbone_urls = { | ||
|
@@ -594,3 +595,136 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i | |
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) | ||
model.load_state_dict(state_dict) | ||
return model | ||
|
||
|
||
class SSDFeatureExtractorResNet(nn.Module):
    """
    SSD feature extractor built from a ResNet trunk plus five extra blocks.

    The trunk is the ResNet stem followed by ``layer1`` .. ``layer4``. Its output
    and the output of each of the five extra blocks are returned by ``forward``,
    giving six feature maps in total.

    Args:
        backbone (resnet.ResNet): the ResNet whose stem and stages are reused.
    """

    def __init__(self, backbone: resnet.ResNet):
        super().__init__()

        # NOTE(review): a reviewer asked why an existing utility was not used to
        # build this trunk (the utility's name is missing from the captured
        # comment), and suggested eventually unifying feature extraction on the
        # FX-based feature extractor, which would be more generic.
        trunk_modules = [
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
        ]
        self.features = nn.Sequential(*trunk_modules)

        # Patch last block's strides to get valid output sizes
        # NOTE(review): a reviewer asked whether this should be achieved by
        # passing a constructor argument instead (the argument's name is missing
        # from the captured comment) -- TODO confirm which one was meant.
        first_block_of_last_stage = self.features[-1][0]
        for module in first_block_of_last_stage.modules():
            if hasattr(module, 'stride'):
                module.stride = 1

        def extra_block(in_ch, mid_ch, out_ch, kernel_size, padding, stride):
            # 1x1 conv to `mid_ch` channels, then a `kernel_size` conv to `out_ch`
            # channels; each conv is followed by BatchNorm and ReLU.
            # NOTE(review): a reviewer suggested refactoring this repeated
            # pattern (the rest of that comment was truncated in the capture).
            return nn.Sequential(
                nn.Conv2d(in_ch, mid_ch, kernel_size=1, bias=False),
                nn.BatchNorm2d(mid_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(mid_ch, out_ch, kernel_size=kernel_size, padding=padding, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )

        backbone_out_channels = self.features[-1][-1].bn3.num_features
        # (in, mid, out, kernel_size, padding, stride) for each extra block, in order.
        block_specs = [
            (backbone_out_channels, 256, 512, 3, 1, 2),
            (512, 256, 512, 3, 1, 2),
            (512, 128, 256, 3, 1, 2),
            (256, 128, 256, 3, 0, 1),
            (256, 128, 256, 2, 0, 1),
        ]
        self.extra = nn.ModuleList([extra_block(*spec) for spec in block_specs])
        _xavier_init(self.extra)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """
        Run the trunk and every extra block, collecting each intermediate output.

        Returns:
            An ``OrderedDict`` keyed by ``"0"``, ``"1"``, ... in the order the
            feature maps are produced: trunk output first, then one entry per
            extra block.
        """
        feature_maps = [self.features(x)]
        for block in self.extra:
            feature_maps.append(block(feature_maps[-1]))

        return OrderedDict((str(idx), fmap) for idx, fmap in enumerate(feature_maps))
|
||
|
||
def _resnet_extractor(backbone_name: str, pretrained: bool, trainable_layers: int):
    """
    Build a ResNet, freeze part of it, and wrap it in ``SSDFeatureExtractorResNet``.

    Args:
        backbone_name (str): name of the constructor to look up in the ``resnet`` module.
        pretrained (bool): forwarded to the ResNet constructor as ``pretrained``.
        trainable_layers (int): how many entries of
            ``['layer4', 'layer3', 'layer2', 'layer1', 'conv1']`` (taken from the
            front) stay trainable; must be between 0 and 5. With 5, ``bn1`` is
            kept trainable as well.

    Returns:
        The feature extractor wrapping the (partially frozen) backbone.
    """
    build_backbone = resnet.__dict__[backbone_name]
    backbone = build_backbone(pretrained=pretrained)

    assert 0 <= trainable_layers <= 5
    stage_order = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1']
    trainable = stage_order[:trainable_layers]
    if trainable_layers == 5:
        trainable.append('bn1')
    trainable_prefixes = tuple(trainable)

    # Any parameter whose name does not start with a trainable prefix is frozen.
    # An empty prefix tuple never matches, so everything is frozen in that case.
    for param_name, param in backbone.named_parameters():
        if not param_name.startswith(trainable_prefixes):
            param.requires_grad_(False)

    return SSDFeatureExtractorResNet(backbone)
|
||
|
||
def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
                    pretrained_backbone: bool = True, trainable_backbone_layers: Optional[int] = None, **kwargs: Any):
    """
    Constructs an SSD model with input size 512x512 and a ResNet50 backbone. See `SSD` for more details.

    Example:

        >>> model = torchvision.models.detection.ssd512_resnet50(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 512, 512), torch.rand(3, 750, 600)]
        >>> predictions = model(x)

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
        trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
        **kwargs: extra keyword arguments forwarded to `SSD`. A ``size`` entry is
            dropped with a warning because this builder always uses (512, 512).

    Raises:
        ValueError: if ``pretrained`` is True and no checkpoint URL is registered for this model.
    """
    if "size" in kwargs:
        warnings.warn("The size of the model is already fixed; ignoring the argument.")
        # Actually discard it: the size is passed positionally to SSD below, so
        # forwarding it again through **kwargs would not be "ignoring" it.
        kwargs.pop("size")

    trainable_backbone_layers = _validate_trainable_layers(
        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5)

    if pretrained:
        # The full detection checkpoint is loaded below, so skip the separate
        # backbone weights.
        pretrained_backbone = False

    backbone = _resnet_extractor("resnet50", pretrained_backbone, trainable_backbone_layers)
    anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]],
                                           scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05])
    model = SSD(backbone, anchor_generator, (512, 512), num_classes, **kwargs)
    if pretrained:
        weights_name = 'ssd512_resnet50_coco'
        if model_urls.get(weights_name, None) is None:
            raise ValueError("No checkpoint is available for model {}".format(weights_name))
        state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
        model.load_state_dict(state_dict)
    return model
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the future: we need to change those tables, as they are currently misleading -- the `test time` column for the SSD models is for a batch size of 4 per GPU, while for Faster R-CNN it was for a batch size of 2.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe do a back-of-the-envelope estimate to bring them to comparable batch sizes?