-
Notifications
You must be signed in to change notification settings - Fork 9.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add SSH Module * Update mmdet/models/necks/ssh.py Co-authored-by: BigDong <yudongwang1226@gmail.com> * Update mmdet/models/necks/ssh.py Co-authored-by: BigDong <yudongwang1226@gmail.com> * Update mmdet/models/necks/ssh.py Co-authored-by: BigDong <yudongwang1226@gmail.com> * Update mmdet/models/necks/ssh.py Co-authored-by: BigDong <yudongwang1226@gmail.com> * Update mmdet/models/necks/ssh.py Co-authored-by: BigDong <yudongwang1226@gmail.com> * Update mmdet/models/necks/ssh.py Co-authored-by: BigDong <yudongwang1226@gmail.com> * Update mmdet/models/necks/ssh.py Co-authored-by: BigDong <yudongwang1226@gmail.com> * Update mmdet/models/necks/ssh.py * Update mmdet/models/necks/ssh.py * fix variable names * decompose the SSHModule into SSHContextModule and SSHDetModule * remove RetinaSSH * add ssh unit test * add `init_cfg` parameter to `SSHContextModule` * add ssh example Co-authored-by: BigDong <yudongwang1226@gmail.com>
- Loading branch information
1 parent
380d936
commit 28513d5
Showing
3 changed files
with
239 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import List, Tuple | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from mmcv.cnn import ConvModule | ||
from mmengine.model import BaseModule | ||
|
||
from mmdet.registry import MODELS | ||
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig | ||
|
||
|
||
class SSHContextModule(BaseModule): | ||
"""This is an implementation of `SSH context module` described in `SSH: | ||
Single Stage Headless Face Detector. | ||
<https://arxiv.org/pdf/1708.03979.pdf>`_. | ||
Args: | ||
in_channels (int): Number of input channels used at each scale. | ||
out_channels (int): Number of output channels used at each scale. | ||
conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for | ||
convolution layer. Defaults to None. | ||
norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization | ||
layer. Defaults to dict(type='BN'). | ||
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or | ||
list[dict], optional): Initialization config dict. | ||
Defaults to None. | ||
""" | ||
|
||
def __init__(self, | ||
in_channels: int, | ||
out_channels: int, | ||
conv_cfg: OptConfigType = None, | ||
norm_cfg: ConfigType = dict(type='BN'), | ||
init_cfg: OptMultiConfig = None): | ||
super().__init__(init_cfg=init_cfg) | ||
assert out_channels % 4 == 0 | ||
|
||
self.in_channels = in_channels | ||
self.out_channels = out_channels | ||
|
||
self.conv5x5_1 = ConvModule( | ||
self.in_channels, | ||
self.out_channels // 4, | ||
3, | ||
stride=1, | ||
padding=1, | ||
conv_cfg=conv_cfg, | ||
norm_cfg=norm_cfg, | ||
) | ||
|
||
self.conv5x5_2 = ConvModule( | ||
self.out_channels // 4, | ||
self.out_channels // 4, | ||
3, | ||
stride=1, | ||
padding=1, | ||
conv_cfg=conv_cfg, | ||
norm_cfg=norm_cfg, | ||
act_cfg=None) | ||
|
||
self.conv7x7_2 = ConvModule( | ||
self.out_channels // 4, | ||
self.out_channels // 4, | ||
3, | ||
stride=1, | ||
padding=1, | ||
conv_cfg=conv_cfg, | ||
norm_cfg=norm_cfg, | ||
) | ||
|
||
self.conv7x7_3 = ConvModule( | ||
self.out_channels // 4, | ||
self.out_channels // 4, | ||
3, | ||
stride=1, | ||
padding=1, | ||
conv_cfg=conv_cfg, | ||
norm_cfg=norm_cfg, | ||
act_cfg=None, | ||
) | ||
|
||
def forward(self, x: torch.Tensor) -> tuple: | ||
conv5x5_1 = self.conv5x5_1(x) | ||
conv5x5 = self.conv5x5_2(conv5x5_1) | ||
conv7x7_2 = self.conv7x7_2(conv5x5_1) | ||
conv7x7 = self.conv7x7_3(conv7x7_2) | ||
|
||
return (conv5x5, conv7x7) | ||
|
||
|
||
class SSHDetModule(BaseModule): | ||
"""This is an implementation of `SSH detection module` described in `SSH: | ||
Single Stage Headless Face Detector. | ||
<https://arxiv.org/pdf/1708.03979.pdf>`_. | ||
Args: | ||
in_channels (int): Number of input channels used at each scale. | ||
out_channels (int): Number of output channels used at each scale. | ||
conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for | ||
convolution layer. Defaults to None. | ||
norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization | ||
layer. Defaults to dict(type='BN'). | ||
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or | ||
list[dict], optional): Initialization config dict. | ||
Defaults to None. | ||
""" | ||
|
||
def __init__(self, | ||
in_channels: int, | ||
out_channels: int, | ||
conv_cfg: OptConfigType = None, | ||
norm_cfg: ConfigType = dict(type='BN'), | ||
init_cfg: OptMultiConfig = None): | ||
super().__init__(init_cfg=init_cfg) | ||
assert out_channels % 4 == 0 | ||
|
||
self.in_channels = in_channels | ||
self.out_channels = out_channels | ||
|
||
self.conv3x3 = ConvModule( | ||
self.in_channels, | ||
self.out_channels // 2, | ||
3, | ||
stride=1, | ||
padding=1, | ||
conv_cfg=conv_cfg, | ||
norm_cfg=norm_cfg, | ||
act_cfg=None) | ||
|
||
self.context_module = SSHContextModule( | ||
in_channels=self.in_channels, | ||
out_channels=self.out_channels, | ||
conv_cfg=conv_cfg, | ||
norm_cfg=norm_cfg) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
conv3x3 = self.conv3x3(x) | ||
conv5x5, conv7x7 = self.context_module(x) | ||
out = torch.cat([conv3x3, conv5x5, conv7x7], dim=1) | ||
out = F.relu(out) | ||
|
||
return out | ||
|
||
|
||
@MODELS.register_module() | ||
class SSH(BaseModule): | ||
"""`SSH Neck` used in `SSH: Single Stage Headless Face Detector. | ||
<https://arxiv.org/pdf/1708.03979.pdf>`_. | ||
Args: | ||
num_scales (int): The number of scales / stages. | ||
in_channels (list[int]): The number of input channels per scale. | ||
out_channels (list[int]): The number of output channels per scale. | ||
conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for | ||
convolution layer. Defaults to None. | ||
norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization | ||
layer. Defaults to dict(type='BN'). | ||
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or | ||
list[dict], optional): Initialization config dict. | ||
Example: | ||
>>> import torch | ||
>>> in_channels = [8, 16, 32, 64] | ||
>>> out_channels = [16, 32, 64, 128] | ||
>>> scales = [340, 170, 84, 43] | ||
>>> inputs = [torch.rand(1, c, s, s) | ||
... for c, s in zip(in_channels, scales)] | ||
>>> self = SSH(num_scales=4, in_channels=in_channels, | ||
... out_channels=out_channels) | ||
>>> outputs = self.forward(inputs) | ||
>>> for i in range(len(outputs)): | ||
... print(f'outputs[{i}].shape = {outputs[i].shape}') | ||
outputs[0].shape = torch.Size([1, 16, 340, 340]) | ||
outputs[1].shape = torch.Size([1, 32, 170, 170]) | ||
outputs[2].shape = torch.Size([1, 64, 84, 84]) | ||
outputs[3].shape = torch.Size([1, 128, 43, 43]) | ||
""" | ||
|
||
def __init__(self, | ||
num_scales: int, | ||
in_channels: List[int], | ||
out_channels: List[int], | ||
conv_cfg: OptConfigType = None, | ||
norm_cfg: ConfigType = dict(type='BN'), | ||
init_cfg: OptMultiConfig = dict( | ||
type='Xavier', layer='Conv2d', distribution='uniform')): | ||
super().__init__(init_cfg=init_cfg) | ||
assert (num_scales == len(in_channels) == len(out_channels)) | ||
self.num_scales = num_scales | ||
self.in_channels = in_channels | ||
self.out_channels = out_channels | ||
|
||
for idx in range(self.num_scales): | ||
in_c, out_c = self.in_channels[idx], self.out_channels[idx] | ||
self.add_module( | ||
f'ssh_module{idx}', | ||
SSHDetModule( | ||
in_channels=in_c, | ||
out_channels=out_c, | ||
conv_cfg=conv_cfg, | ||
norm_cfg=norm_cfg)) | ||
|
||
def forward(self, inputs: Tuple[torch.Tensor]) -> tuple: | ||
assert len(inputs) == self.num_scales | ||
|
||
outs = [] | ||
for idx, x in enumerate(inputs): | ||
ssh_module = getattr(self, f'ssh_module{idx}') | ||
out = ssh_module(x) | ||
outs.append(out) | ||
|
||
return tuple(outs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters