Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add SSH Module #8953

Merged
merged 16 commits into from
Dec 15, 2022
Merged
3 changes: 2 additions & 1 deletion mmdet/models/necks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
from .pafpn import PAFPN
from .rfp import RFP
from .ssd_neck import SSDNeck
from .ssh import SSH
from .yolo_neck import YOLOV3Neck
from .yolox_pafpn import YOLOXPAFPN

# Public neck classes re-exported by ``mmdet.models.necks``. Every entry is
# comma-separated on purpose: two adjacent string literals without a comma are
# silently concatenated into a single (non-existent) name.
__all__ = [
    'FPN', 'BFP', 'ChannelMapper', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN',
    'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG', 'DilatedEncoder',
    'CTResNetNeck', 'SSDNeck', 'YOLOXPAFPN', 'DyHead', 'CSPNeXtPAFPN', 'SSH'
]
216 changes: 216 additions & 0 deletions mmdet/models/necks/ssh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
MambaWong marked this conversation as resolved.
Show resolved Hide resolved

import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig


class SSHContextModule(BaseModule):
    """Context module from `SSH: Single Stage Headless Face Detector
    <https://arxiv.org/pdf/1708.03979.pdf>`_.

    The module is built from 3x3 convolutions (stride 1, padding 1) only.
    A shared first convolution feeds two branches:

    - the "5x5" branch applies one more 3x3 convolution;
    - the "7x7" branch applies two more 3x3 convolutions.

    Every convolution outputs ``out_channels // 4`` channels. The last
    convolution of each branch is created with ``act_cfg=None``.

    Args:
        in_channels (int): Number of input channels used at each scale.
        out_channels (int): Number of output channels used at each scale.
            Must be divisible by 4.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN').
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN'),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)
        assert out_channels % 4 == 0

        self.in_channels = in_channels
        self.out_channels = out_channels

        # All four convolutions share the same geometry and layer configs.
        branch_channels = out_channels // 4
        shared_kwargs = dict(
            kernel_size=3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

        # Stem shared by both branches.
        self.conv5x5_1 = ConvModule(in_channels, branch_channels,
                                    **shared_kwargs)
        # Tail of the "5x5" branch.
        self.conv5x5_2 = ConvModule(
            branch_channels, branch_channels, act_cfg=None, **shared_kwargs)
        # Middle and tail of the "7x7" branch.
        self.conv7x7_2 = ConvModule(branch_channels, branch_channels,
                                    **shared_kwargs)
        self.conv7x7_3 = ConvModule(
            branch_channels, branch_channels, act_cfg=None, **shared_kwargs)

    def forward(self, x: torch.Tensor) -> tuple:
        """Run both context branches on ``x``.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            tuple: ``(conv5x5, conv7x7)``, the outputs of the two branches.
        """
        stem = self.conv5x5_1(x)
        branch_5x5 = self.conv5x5_2(stem)
        branch_7x7 = self.conv7x7_3(self.conv7x7_2(stem))
        return (branch_5x5, branch_7x7)


class SSHDetModule(BaseModule):
    """Detection module from `SSH: Single Stage Headless Face Detector
    <https://arxiv.org/pdf/1708.03979.pdf>`_.

    A 3x3 convolution producing ``out_channels // 2`` channels runs in
    parallel with an :class:`SSHContextModule`. The three resulting feature
    maps are concatenated along the channel dimension and passed through a
    single ReLU.

    Args:
        in_channels (int): Number of input channels used at each scale.
        out_channels (int): Number of output channels used at each scale.
            Must be divisible by 4.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN').
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN'),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)
        assert out_channels % 4 == 0

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Direct branch; created without its own activation because
        # ``forward`` applies one ReLU to the concatenated result.
        self.conv3x3 = ConvModule(
            in_channels,
            out_channels // 2,
            kernel_size=3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        # Context branches supplying the remaining channels.
        self.context_module = SSHContextModule(
            in_channels=in_channels,
            out_channels=out_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Fuse the direct and context branches.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            torch.Tensor: ReLU of the channel-wise concatenation of the
            3x3, "5x5" and "7x7" branch outputs.
        """
        direct = self.conv3x3(x)
        context_5x5, context_7x7 = self.context_module(x)
        fused = torch.cat((direct, context_5x5, context_7x7), dim=1)
        return F.relu(fused)


@MODELS.register_module()
class SSH(BaseModule):
    """`SSH Neck` used in `SSH: Single Stage Headless Face Detector
    <https://arxiv.org/pdf/1708.03979.pdf>`_.

    One :class:`SSHDetModule` is created per scale and registered as
    ``ssh_module{idx}``. In ``forward`` the module at index ``idx`` is
    applied to the input feature map at the same index.

    Args:
        num_scales (int): The number of scales / stages.
        in_channels (list[int]): The number of input channels per scale.
        out_channels (list[int]): The number of output channels per scale.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN').
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict. Defaults to
            dict(type='Xavier', layer='Conv2d', distribution='uniform').

    Example:
        >>> import torch
        >>> in_channels = [8, 16, 32, 64]
        >>> out_channels = [16, 32, 64, 128]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = SSH(num_scales=4, in_channels=in_channels,
        ...            out_channels=out_channels)
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 16, 340, 340])
        outputs[1].shape = torch.Size([1, 32, 170, 170])
        outputs[2].shape = torch.Size([1, 64, 84, 84])
        outputs[3].shape = torch.Size([1, 128, 43, 43])
    """

    def __init__(self,
                 num_scales: int,
                 in_channels: List[int],
                 out_channels: List[int],
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN'),
                 init_cfg: OptMultiConfig = dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super().__init__(init_cfg=init_cfg)
        # The per-scale channel lists must describe exactly ``num_scales``
        # scales.
        assert (num_scales == len(in_channels) == len(out_channels))
        self.num_scales = num_scales
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Submodule names are part of the checkpoint layout
        # (``ssh_module0``, ``ssh_module1``, ...); keep them stable.
        for idx in range(self.num_scales):
            in_c, out_c = self.in_channels[idx], self.out_channels[idx]
            self.add_module(
                f'ssh_module{idx}',
                SSHDetModule(
                    in_channels=in_c,
                    out_channels=out_c,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg))

    def forward(self, inputs: Tuple[torch.Tensor]) -> tuple:
        """Apply the per-scale SSH detection modules.

        Args:
            inputs (tuple[torch.Tensor]): One feature map per scale; its
                length must equal ``num_scales``.

        Returns:
            tuple: One output feature map per input, in the same order.
        """
        assert len(inputs) == self.num_scales

        outs = []
        for idx, x in enumerate(inputs):
            ssh_module = getattr(self, f'ssh_module{idx}')
            outs.append(ssh_module(x))

        return tuple(outs)
22 changes: 21 additions & 1 deletion tests/test_models/test_necks/test_necks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch.nn.modules.batchnorm import _BatchNorm

from mmdet.models.necks import (FPG, FPN, FPN_CARAFE, NASFCOS_FPN, NASFPN,
from mmdet.models.necks import (FPG, FPN, FPN_CARAFE, NASFCOS_FPN, NASFPN, SSH,
YOLOXPAFPN, ChannelMapper, DilatedEncoder,
DyHead, SSDNeck, YOLOV3Neck)

Expand Down Expand Up @@ -629,3 +629,23 @@ def test_nasfcos_fpn():
start_level=1,
end_level=2,
num_outs=3)


def test_ssh_neck():
    """Check the number and shapes of the feature maps produced by SSH."""
    base_size = 64
    in_channels = [8, 16, 32, 64]
    # Halve the resolution at every level: [64, 32, 16, 8].
    feat_sizes = [base_size // 2**level for level in range(4)]
    out_channels = [16, 32, 64, 128]
    ssh_model = SSH(
        num_scales=4, in_channels=in_channels, out_channels=out_channels)

    feats = [
        torch.rand(1, channels, size, size)
        for channels, size in zip(in_channels, feat_sizes)
    ]
    outs = ssh_model(feats)
    assert len(outs) == len(feats)
    # Each output keeps its input's spatial size and has the configured
    # number of output channels.
    for out, channels, size in zip(outs, out_channels, feat_sizes):
        assert out.shape == (1, channels, size, size)