Skip to content

Commit

Permalink
[Feat]: add SSH Module (#8953)
Browse files Browse the repository at this point in the history
* add SSH Module

* Update mmdet/models/necks/ssh.py

Co-authored-by: BigDong <yudongwang1226@gmail.com>

* Update mmdet/models/necks/ssh.py

Co-authored-by: BigDong <yudongwang1226@gmail.com>

* Update mmdet/models/necks/ssh.py

Co-authored-by: BigDong <yudongwang1226@gmail.com>

* Update mmdet/models/necks/ssh.py

Co-authored-by: BigDong <yudongwang1226@gmail.com>

* Update mmdet/models/necks/ssh.py

Co-authored-by: BigDong <yudongwang1226@gmail.com>

* Update mmdet/models/necks/ssh.py

Co-authored-by: BigDong <yudongwang1226@gmail.com>

* Update mmdet/models/necks/ssh.py

Co-authored-by: BigDong <yudongwang1226@gmail.com>

* Update mmdet/models/necks/ssh.py

* Update mmdet/models/necks/ssh.py

* fix variable names

* decompose the SSHModule into SSHContextModule and SSHDetModule

* remove RetinaSSH

* add ssh unit test

* add `init_cfg` parameter to `SSHContextModule`

* add ssh example

Co-authored-by: BigDong <yudongwang1226@gmail.com>
  • Loading branch information
MambaWong and BIGWangYuDong authored Dec 15, 2022
1 parent 380d936 commit 28513d5
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 2 deletions.
3 changes: 2 additions & 1 deletion mmdet/models/necks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
from .pafpn import PAFPN
from .rfp import RFP
from .ssd_neck import SSDNeck
from .ssh import SSH
from .yolo_neck import YOLOV3Neck
from .yolox_pafpn import YOLOXPAFPN

__all__ = [
'FPN', 'BFP', 'ChannelMapper', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN',
'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG', 'DilatedEncoder',
'CTResNetNeck', 'SSDNeck', 'YOLOXPAFPN', 'DyHead', 'CSPNeXtPAFPN'
'CTResNetNeck', 'SSDNeck', 'YOLOXPAFPN', 'DyHead', 'CSPNeXtPAFPN', 'SSH'
]
216 changes: 216 additions & 0 deletions mmdet/models/necks/ssh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig


class SSHContextModule(BaseModule):
    """Context module from `SSH: Single Stage Headless Face Detector
    <https://arxiv.org/pdf/1708.03979.pdf>`_.

    A shared 3x3 conv feeds two branches built only from stride-1 3x3
    convs: the ``5x5`` branch adds one more conv (two stacked 3x3 convs)
    and the ``7x7`` branch adds two more (three stacked 3x3 convs). Every
    conv produces ``out_channels // 4`` channels. The last conv of each
    branch is built with ``act_cfg=None``; the other two keep
    ``ConvModule``'s default ``act_cfg``.

    Args:
        in_channels (int): Number of input channels used at each scale.
        out_channels (int): Number of output channels used at each scale.
            Must be divisible by 4.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN').
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN'),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)
        assert out_channels % 4 == 0

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Every conv in this module emits a quarter of ``out_channels``.
        branch_channels = self.out_channels // 4
        # Keyword arguments common to all four convs.
        shared = dict(
            stride=1, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg)

        # Stem shared by both branches.
        self.conv5x5_1 = ConvModule(self.in_channels, branch_channels, 3,
                                    **shared)
        # Tail of the 5x5 branch; built without an activation layer.
        self.conv5x5_2 = ConvModule(
            branch_channels, branch_channels, 3, act_cfg=None, **shared)
        # Middle conv of the 7x7 branch.
        self.conv7x7_2 = ConvModule(branch_channels, branch_channels, 3,
                                    **shared)
        # Tail of the 7x7 branch; built without an activation layer.
        self.conv7x7_3 = ConvModule(
            branch_channels, branch_channels, 3, act_cfg=None, **shared)

    def forward(self, x: torch.Tensor) -> tuple:
        """Run both context branches on ``x``.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            tuple: ``(conv5x5, conv7x7)``, the outputs of the two branches.
        """
        stem = self.conv5x5_1(x)
        branch_5x5 = self.conv5x5_2(stem)
        branch_7x7 = self.conv7x7_3(self.conv7x7_2(stem))

        return (branch_5x5, branch_7x7)


class SSHDetModule(BaseModule):
    """Detection module from `SSH: Single Stage Headless Face Detector
    <https://arxiv.org/pdf/1708.03979.pdf>`_.

    The input goes through a plain 3x3 conv producing
    ``out_channels // 2`` channels and, in parallel, through an
    :class:`SSHContextModule` producing two maps of ``out_channels // 4``
    channels each. The three maps are concatenated along the channel
    dimension (``out_channels`` channels in total) and passed through a
    ReLU.

    Args:
        in_channels (int): Number of input channels used at each scale.
        out_channels (int): Number of output channels used at each scale.
            Must be divisible by 4.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN').
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN'),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)
        assert out_channels % 4 == 0

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Layer configs forwarded unchanged to both sub-branches.
        layer_cfgs = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg)

        # Direct branch: built without an activation layer because the
        # ReLU is applied once, after concatenation, in ``forward``.
        self.conv3x3 = ConvModule(
            self.in_channels,
            self.out_channels // 2,
            3,
            stride=1,
            padding=1,
            act_cfg=None,
            **layer_cfgs)

        # Context branch: supplies the remaining half of the channels.
        self.context_module = SSHContextModule(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            **layer_cfgs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Fuse the direct and context branches.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            torch.Tensor: ReLU of the channel-wise concatenation of the
            3x3, 5x5 and 7x7 branch outputs.
        """
        direct = self.conv3x3(x)
        ctx_5x5, ctx_7x7 = self.context_module(x)
        fused = torch.cat([direct, ctx_5x5, ctx_7x7], dim=1)

        return F.relu(fused)


@MODELS.register_module()
class SSH(BaseModule):
    """`SSH Neck` used in `SSH: Single Stage Headless Face Detector.

    <https://arxiv.org/pdf/1708.03979.pdf>`_.

    One independent :class:`SSHDetModule` is created per scale and applied
    to the feature map of that scale; no information is exchanged between
    scales.

    Args:
        num_scales (int): The number of scales / stages.
        in_channels (list[int]): The number of input channels per scale.
        out_channels (list[int]): The number of output channels per scale.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN').
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to dict(type='Xavier', layer='Conv2d',
            distribution='uniform').

    Example:
        >>> import torch
        >>> in_channels = [8, 16, 32, 64]
        >>> out_channels = [16, 32, 64, 128]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = SSH(num_scales=4, in_channels=in_channels,
        ...            out_channels=out_channels)
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 16, 340, 340])
        outputs[1].shape = torch.Size([1, 32, 170, 170])
        outputs[2].shape = torch.Size([1, 64, 84, 84])
        outputs[3].shape = torch.Size([1, 128, 43, 43])
    """

    def __init__(self,
                 num_scales: int,
                 in_channels: List[int],
                 out_channels: List[int],
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN'),
                 init_cfg: OptMultiConfig = dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super().__init__(init_cfg=init_cfg)
        assert num_scales == len(in_channels) == len(out_channels), (
            f'num_scales ({num_scales}), len(in_channels) '
            f'({len(in_channels)}) and len(out_channels) '
            f'({len(out_channels)}) must all be equal')
        self.num_scales = num_scales
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Submodules are registered as ``ssh_module{idx}``; ``forward``
        # looks them up by the same name, and the name is part of the
        # state-dict keys, so it must not change.
        for idx, (in_c, out_c) in enumerate(
                zip(self.in_channels, self.out_channels)):
            self.add_module(
                f'ssh_module{idx}',
                SSHDetModule(
                    in_channels=in_c,
                    out_channels=out_c,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg))

    def forward(
            self,
            inputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """Apply the per-scale detection module to each input feature map.

        Args:
            inputs (tuple[torch.Tensor]): One feature map per scale, in the
                same order as ``in_channels``. Any sequence of length
                ``num_scales`` is accepted.

        Returns:
            tuple[torch.Tensor]: One output feature map per scale, in the
            same order as ``inputs``.

        Raises:
            AssertionError: If ``len(inputs)`` differs from ``num_scales``.
        """
        assert len(inputs) == self.num_scales, (
            f'expected {self.num_scales} feature maps, got {len(inputs)}')

        return tuple(
            getattr(self, f'ssh_module{idx}')(feat)
            for idx, feat in enumerate(inputs))
22 changes: 21 additions & 1 deletion tests/test_models/test_necks/test_necks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch.nn.modules.batchnorm import _BatchNorm

from mmdet.models.necks import (FPG, FPN, FPN_CARAFE, NASFCOS_FPN, NASFPN,
from mmdet.models.necks import (FPG, FPN, FPN_CARAFE, NASFCOS_FPN, NASFPN, SSH,
YOLOXPAFPN, ChannelMapper, DilatedEncoder,
DyHead, SSDNeck, YOLOV3Neck)

Expand Down Expand Up @@ -629,3 +629,23 @@ def test_nasfcos_fpn():
start_level=1,
end_level=2,
num_outs=3)


def test_ssh_neck():
    """Check that SSH returns one map per input level, with the configured
    output channels and unchanged spatial size."""
    base_size = 64
    num_levels = 4
    in_channels = [8, 16, 32, 64]
    out_channels = [16, 32, 64, 128]
    # Halve the resolution at every level: [64, 32, 16, 8]
    feat_sizes = [base_size // 2**level for level in range(num_levels)]

    neck = SSH(
        num_scales=num_levels,
        in_channels=in_channels,
        out_channels=out_channels)

    feats = [
        torch.rand(1, channels, size, size)
        for channels, size in zip(in_channels, feat_sizes)
    ]
    outs = neck(feats)

    assert len(outs) == len(feats)
    for out, channels, size in zip(outs, out_channels, feat_sizes):
        assert out.shape == (1, channels, size, size)

0 comments on commit 28513d5

Please sign in to comment.