Add SSD512 with ResNet50 backbone #3760

Closed
wants to merge 25 commits into from
25 commits
2f1f578
Add experimental resnet50 backbone.
datumbox May 3, 2021
d46a302
Merge branch 'master' into models/ssd_resnet
datumbox May 3, 2021
0c17b0a
Passing custom scales (necessary after master merge).
datumbox May 3, 2021
b640680
Add experimental FPN-style resnet50 backbone.
datumbox May 3, 2021
36163dc
Add experimental VGG-style resnet50 backbone.
datumbox May 3, 2021
2c0f46d
Add a highres option to support both the 300 and 512 versions.
datumbox May 4, 2021
eef01bc
Select best performing prototype.
datumbox May 6, 2021
9cf7c5d
Adding documentation.
datumbox May 6, 2021
d419eea
Adding weights.
datumbox May 6, 2021
e6fb426
Merge branch 'master' into models/ssd_resnet
datumbox May 6, 2021
40da375
Merge branch 'master' into models/ssd_resnet
datumbox May 6, 2021
e66b800
Merge branch 'master' into models/ssd_resnet
datumbox May 7, 2021
e5472d0
Merge branch 'master' into models/ssd_resnet
datumbox May 7, 2021
e526e32
Merge branch 'master' into models/ssd_resnet
datumbox May 7, 2021
18495f3
Merge branch 'master' into models/ssd_resnet
datumbox May 10, 2021
14299da
Merge branch 'master' into models/ssd_resnet
datumbox May 11, 2021
ea1e2c4
Fix not implemented for half exception
datumbox May 11, 2021
04ec56a
Merge branch 'master' into models/ssd_resnet
datumbox May 11, 2021
777126d
Apply recommendations from code review.
datumbox May 11, 2021
87d0153
Updating docs.
datumbox May 11, 2021
8b2715d
Change the way we rescale to [-1, 1]
datumbox May 11, 2021
d08fc10
Change the way we rescale input on SSD300+VGG16
datumbox May 11, 2021
61ae292
Add comment.
datumbox May 11, 2021
18bf381
Merge branch 'master' into models/ssd_resnet
datumbox May 11, 2021
644bdcd
Merge branch 'master' into models/ssd_resnet
datumbox May 19, 2021
3 changes: 3 additions & 0 deletions docs/source/models.rst
@@ -431,6 +431,7 @@ Faster R-CNN MobileNetV3-Large FPN 32.8 - -
Faster R-CNN MobileNetV3-Large 320 FPN 22.8    -        -
RetinaNet ResNet-50 FPN                36.4    -        -
SSD300 VGG16                           25.1    -        -
SSD512 ResNet-50                       30.2    -        -
SSDlite320 MobileNetV3-Large           21.3    -        -
Mask R-CNN ResNet-50 FPN               37.9    34.6     -
====================================== ======= ======== ===========
@@ -491,6 +492,7 @@ Faster R-CNN MobileNetV3-Large FPN 0.1020 0.0415
Faster R-CNN MobileNetV3-Large 320 FPN 0.0978 0.0376 0.6
RetinaNet ResNet-50 FPN                0.2514 0.0939 4.1
SSD300 VGG16                           0.2093 0.0744 1.5
SSD512 ResNet-50                       0.2316 0.0772 3.0
Member

For the future: we need to change those tables, as they are currently misleading -- the test time column for the SSD models is for a batch size of 4 per GPU, while for Faster R-CNN it was for a batch size of 2.

Contributor Author

Maybe do a back-of-the-envelope estimate to bring them to comparable batch sizes?

SSDlite320 MobileNetV3-Large           0.1773 0.0906 1.5
Mask R-CNN ResNet-50 FPN               0.2728 0.0903 5.4
Keypoint R-CNN ResNet-50 FPN           0.3789 0.1242 6.8
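
The back-of-the-envelope normalization raised in the review thread above could look roughly like the sketch below. It assumes that the second numeric column of the table is the test time in seconds per iteration and that latency grows about linearly with batch size, which is only an approximation. `per_image_time` is a hypothetical helper, and the batch sizes are the ones quoted in the review comment.

```python
def per_image_time(seconds_per_iteration: float, batch_size: int) -> float:
    """Rough per-image latency, assuming time grows linearly with batch size."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return seconds_per_iteration / batch_size


# (test time per iteration from the table, per-GPU batch size quoted in the review)
rows = {
    "SSD300 VGG16": (0.0744, 4),
    "SSD512 ResNet-50": (0.0772, 4),
    "Faster R-CNN MobileNetV3-Large 320 FPN": (0.0376, 2),
}

normalized = {name: per_image_time(t, b) for name, (t, b) in rows.items()}
for name, seconds in normalized.items():
    print(f"{name}: {seconds:.4f} s/image")
```

The resulting numbers are only indicative, since GPU throughput rarely scales perfectly linearly with batch size.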
@@ -515,6 +517,7 @@ SSD
---

.. autofunction:: torchvision.models.detection.ssd300_vgg16
.. autofunction:: torchvision.models.detection.ssd512_resnet50


SSDlite
8 changes: 8 additions & 0 deletions references/detection/README.md
@@ -56,6 +56,14 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--weight-decay 0.0005 --data-augmentation ssd
```

### SSD512 ResNet-50
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
    --dataset coco --model ssd512_resnet50 --epochs 120\
    --lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\
    --weight-decay 0.0005 --data-augmentation ssd
```
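
To make the SSD512 ResNet-50 command above more concrete, here is a small sketch of the schedule it implies. It assumes `--batch-size` is counted per GPU and that the learning rate is multiplied by 0.1 at each `--lr-steps` milestone; the decay factor does not appear in the command and is an assumption here. `multistep_lr` is a hypothetical helper, not part of the training script.

```python
from bisect import bisect_right


def multistep_lr(epoch: int, base_lr: float = 0.002,
                 milestones=(80, 110), gamma: float = 0.1) -> float:
    """Learning rate at `epoch` when it is multiplied by `gamma` at each milestone."""
    # bisect_right counts how many milestones have been reached at this epoch.
    return base_lr * gamma ** bisect_right(milestones, epoch)


# 8 processes (one per GPU) times 4 images per process.
effective_batch_size = 8 * 4

for epoch in (0, 79, 80, 109, 110, 119):
    print(epoch, f"{multistep_lr(epoch):.6f}")
```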

### SSDlite320 MobileNetV3-Large
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
Binary file not shown.
1 change: 1 addition & 0 deletions test/test_models.py
@@ -46,6 +46,7 @@ def get_available_video_models():
    "keypointrcnn_resnet50_fpn": lambda x: x[1],
    "retinanet_resnet50_fpn": lambda x: x[1],
    "ssd300_vgg16": lambda x: x[1],
    "ssd512_resnet50": lambda x: x[1],
    "ssdlite320_mobilenet_v3_large": lambda x: x[1],
}

138 changes: 136 additions & 2 deletions torchvision/models/detection/ssd.py
@@ -10,14 +10,15 @@
from .anchor_utils import DefaultBoxGenerator
from .backbone_utils import _validate_trainable_layers
from .transform import GeneralizedRCNNTransform
-from .. import vgg
+from .. import vgg, resnet
from ..utils import load_state_dict_from_url
from ...ops import boxes as box_ops

-__all__ = ['SSD', 'ssd300_vgg16']
+__all__ = ['SSD', 'ssd300_vgg16', 'ssd512_resnet50']

model_urls = {
    'ssd300_vgg16_coco': 'https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth',
    'ssd512_resnet50_coco': 'https://download.pytorch.org/models/ssd512_resnet50_coco-d6d7edbb.pth',
}

backbone_urls = {
@@ -594,3 +595,136 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i
        state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
        model.load_state_dict(state_dict)
    return model


class SSDFeatureExtractorResNet(nn.Module):
    def __init__(self, backbone: resnet.ResNet):
        super().__init__()

        self.features = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
        )
Comment on lines +604 to +613
Member

Any particular reason why you didn't use the IntermediateLayerGetter?

Also, for the future, I think we will want to unify the way we extract features so that we rely on the FX-based feature extractor, which will be more generic.


        # Patch last block's strides to get valid output sizes
        for m in self.features[-1][0].modules():
            if hasattr(m, 'stride'):
                m.stride = 1
Comment on lines +616 to +618
Member

Wouldn't we instead want to pass dilation=[False, False, True] to the ResNet? Just replacing the stride in the last layer without adding dilation means that the features from the last block no longer behave the way they were originally trained to.


        backbone_out_channels = self.features[-1][-1].bn3.num_features
        extra = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(backbone_out_channels, 256, kernel_size=1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2, bias=False),
                nn.BatchNorm2d(512),
                nn.ReLU(inplace=True),
Comment on lines +622 to +628
Member

nit for the future: it might be good to refactor this into a Block class or something like that, which inherits from Sequential so that we keep the same names for the modules.
Something like

class ExtraBlock(nn.Sequential):
    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__(nn.Conv2d(...), ...)
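
One possible way to flesh out the suggested `ExtraBlock` is sketched below. The `kernel_size`, `stride` and `padding` parameters are assumptions added so that one class can cover all five extra blocks in this diff; they are not part of the reviewer's snippet. Because the layers are passed positionally to `nn.Sequential`, the submodules keep the names `0` to `5`, which is what keeps existing `state_dict` keys compatible.

```python
import torch
from torch import nn


class ExtraBlock(nn.Sequential):
    """1x1 bottleneck conv followed by a larger conv, each with BatchNorm and ReLU."""

    def __init__(self, in_channels: int, mid_channels: int, out_channels: int,
                 kernel_size: int = 3, stride: int = 1, padding: int = 0):
        super().__init__(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=kernel_size,
                      padding=padding, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )


# Equivalent of the first extra block in the diff (2048 -> 256 -> 512, stride 2).
block = ExtraBlock(2048, 256, 512, stride=2, padding=1).eval()
with torch.no_grad():
    out = block(torch.rand(1, 2048, 32, 32))
print(tuple(out.shape))
```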

            ),
            nn.Sequential(
                nn.Conv2d(512, 256, kernel_size=1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2, bias=False),
                nn.BatchNorm2d(512),
                nn.ReLU(inplace=True),
            ),
            nn.Sequential(
                nn.Conv2d(512, 128, kernel_size=1, bias=False),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
            ),
            nn.Sequential(
                nn.Conv2d(256, 128, kernel_size=1, bias=False),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 256, kernel_size=3, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
            ),
            nn.Sequential(
                nn.Conv2d(256, 128, kernel_size=1, bias=False),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 256, kernel_size=2, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
            )
        ])
        _xavier_init(extra)
        self.extra = extra

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        x = self.features(x)
        output = [x]

        for block in self.extra:
            x = block(x)
            output.append(x)

        return OrderedDict([(str(i), v) for i, v in enumerate(output)])
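

The "valid output sizes" remark in the extractor above can be checked with plain arithmetic. The sketch below applies the standard convolution output-size formula to a 512x512 input; it assumes each ResNet stage changes the spatial size like a single 3x3 convolution with padding 1, and `conv_out` and `feature_map_sizes` are hypothetical helpers written for this illustration.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def feature_map_sizes(input_size: int = 512, layer4_stride: int = 1) -> list:
    x = conv_out(input_size, kernel=7, stride=2, padding=3)      # conv1
    x = conv_out(x, kernel=3, stride=2, padding=1)               # maxpool
    x = conv_out(x, kernel=3, stride=1, padding=1)               # layer1
    x = conv_out(x, kernel=3, stride=2, padding=1)               # layer2
    x = conv_out(x, kernel=3, stride=2, padding=1)               # layer3
    x = conv_out(x, kernel=3, stride=layer4_stride, padding=1)   # layer4
    sizes = [x]
    # (kernel, stride, padding) of the second conv in each of the five extra blocks.
    for kernel, stride, padding in [(3, 2, 1), (3, 2, 1), (3, 2, 1), (3, 1, 0), (2, 1, 0)]:
        x = conv_out(x, kernel, stride, padding)
        sizes.append(x)
    return sizes


print(feature_map_sizes(512, layer4_stride=1))  # stride patched to 1, as in the diff
print(feature_map_sizes(512, layer4_stride=2))  # original ResNet stride
```

With the patched stride the six feature maps shrink from 32x32 down to 1x1. With the original stride of 2, the unpadded convolutions in the last two extra blocks would receive inputs smaller than their kernels, so the computed sizes become non-positive.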


def _resnet_extractor(backbone_name: str, pretrained: bool, trainable_layers: int):
    backbone = resnet.__dict__[backbone_name](pretrained=pretrained)

    assert 0 <= trainable_layers <= 5
    layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]
    if trainable_layers == 5:
        layers_to_train.append('bn1')
    for name, parameter in backbone.named_parameters():
        if all([not name.startswith(layer) for layer in layers_to_train]):
            parameter.requires_grad_(False)

    return SSDFeatureExtractorResNet(backbone)
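

The freezing rule in `_resnet_extractor` can be illustrated without building a model. The sketch below reproduces the same prefix matching on a hand-written list of parameter names; `frozen_parameters` and the sample names are hypothetical and only mimic what `backbone.named_parameters()` would yield.

```python
def frozen_parameters(parameter_names, trainable_layers: int):
    """Return the names that would get requires_grad=False under the rule above."""
    if not 0 <= trainable_layers <= 5:
        raise ValueError("trainable_layers must be between 0 and 5")
    layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]
    if trainable_layers == 5:
        layers_to_train.append('bn1')
    # A parameter is frozen when its name starts with none of the trainable prefixes.
    return [name for name in parameter_names
            if not any(name.startswith(layer) for layer in layers_to_train)]


names = [
    "conv1.weight", "bn1.weight",
    "layer1.0.conv1.weight", "layer2.0.conv1.weight",
    "layer3.0.conv1.weight", "layer4.0.conv1.weight",
    "fc.weight",
]

for n in (0, 2, 5):
    print(n, frozen_parameters(names, n))
```

Note that `fc` parameters are frozen for every setting under this rule, which is harmless here because the extractor never uses the classification head.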


def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
                    pretrained_backbone: bool = True, trainable_backbone_layers: Optional[int] = None, **kwargs: Any):
    """
    Constructs an SSD model with input size 512x512 and a ResNet50 backbone. See `SSD` for more details.

    Example:

        >>> model = torchvision.models.detection.ssd512_resnet50(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 512, 512), torch.rand(3, 750, 600)]
        >>> predictions = model(x)

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
        trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
    """
    if "size" in kwargs:
        warnings.warn("The size of the model is already fixed; ignoring the argument.")

    trainable_backbone_layers = _validate_trainable_layers(
        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5)

    if pretrained:
        pretrained_backbone = False

    backbone = _resnet_extractor("resnet50", pretrained_backbone, trainable_backbone_layers)
    anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]],
                                           scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05])
    model = SSD(backbone, anchor_generator, (512, 512), num_classes, **kwargs)
    if pretrained:
        weights_name = 'ssd512_resnet50_coco'
        if model_urls.get(weights_name, None) is None:
            raise ValueError("No checkpoint is available for model {}".format(weights_name))
        state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
        model.load_state_dict(state_dict)
    return model
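

To give a feel for what the `DefaultBoxGenerator` arguments above imply, the sketch below counts default boxes following the usual SSD recipe: two square boxes per location plus two boxes per extra aspect ratio. The feature-map sizes `[32, 16, 8, 4, 2, 1]` are an assumption for a 512x512 input, and `box_shapes` is a hypothetical helper that mirrors the recipe rather than calling torchvision.

```python
import math

aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2], [2]]
scales = [0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05]
feature_map_sizes = [32, 16, 8, 4, 2, 1]  # assumed sizes for a 512x512 input


def box_shapes(k: int):
    """(width, height) pairs, relative to the image, for feature map k."""
    s_k, s_next = scales[k], scales[k + 1]
    # Two square boxes: scale s_k and the geometric mean of s_k and s_{k+1}.
    shapes = [(s_k, s_k), (math.sqrt(s_k * s_next), math.sqrt(s_k * s_next))]
    for ratio in aspect_ratios[k]:
        w, h = s_k * math.sqrt(ratio), s_k / math.sqrt(ratio)
        shapes.extend([(w, h), (h, w)])  # the ratio and its reciprocal
    return shapes


boxes_per_location = [len(box_shapes(k)) for k in range(len(aspect_ratios))]
total_boxes = sum(size * size * n for size, n in zip(feature_map_sizes, boxes_per_location))

print(boxes_per_location)
print(total_boxes)
```

This also shows why `scales` has seven entries for six feature maps: the extra value is needed for the geometric-mean box of the last map.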