Commit 7257c9b

Update mmdet/models/necks/ssh.py
MambaWong committed Oct 10, 2022
1 parent ffb65b4 commit 7257c9b
Showing 1 changed file with 29 additions and 26 deletions.
55 changes: 29 additions & 26 deletions mmdet/models/necks/ssh.py
@@ -7,7 +7,7 @@
 from mmengine.model import BaseModule
 
 from mmdet.registry import MODELS
-from mmdet.utils import ConfigType, OptConfigType
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
 
 
 class SSHModule(BaseModule):
@@ -17,19 +17,21 @@ class SSHModule(BaseModule):
     Args:
         in_channels (int): Number of input channels used at each scale.
         out_channels (int): Number of output channels used at each scale.
-        conv_cfg (dict, optional): Config dict for convolution layer.
-        norm_cfg (dict): Config dict for normalization layer.
-            Defaults to dict(type='BN').
-        init_cfg (dict or list[dict], optional): Initialization config dict.
+        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
+            convolution layer. Defaults to None.
+        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
+            layer. Defaults to dict(type='BN').
+        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
+            list[dict], optional): Initialization config dict.
+            Defaults to None.
     """
 
     def __init__(self,
-                 in_channels,
-                 out_channels,
-                 conv_cfg=None,
+                 in_channels: int,
+                 out_channels: int,
+                 conv_cfg: OptConfigType = None,
                  norm_cfg: ConfigType = dict(type='BN'),
-                 init_cfg=None):
+                 init_cfg: OptMultiConfig = None):
         super().__init__(init_cfg=init_cfg)
         assert out_channels % 4 == 0
@@ -107,23 +109,23 @@ class SSH(BaseModule):
     Args:
         num_scales (int): The number of scales / stages.
-        in_channels (List[int]): The number of input channels per scale.
-        out_channels (List[int]): The number of output channels per scale.
-        conv_cfg (dict, optional): Config dict for convolution layer.
-            Defaults to None.
-        norm_cfg (dict, optional): Dictionary to construct and config norm
-            layer. Defaults to dict(type='BN', requires_grad=True)
-        init_cfg (dict or list[dict], optional): Initialization config dict.
-            Default: None
+        in_channels (list[int]): The number of input channels per scale.
+        out_channels (list[int]): The number of output channels per scale.
+        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
+            convolution layer. Defaults to None.
+        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
+            layer. Defaults to dict(type='BN').
+        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
+            list[dict], optional): Initialization config dict.
     """
 
     def __init__(self,
-                 num_scales,
+                 num_scales: int,
                  in_channels: List[int],
                  out_channels: List[int],
                  conv_cfg: OptConfigType = None,
                  norm_cfg: ConfigType = dict(type='BN'),
-                 init_cfg: ConfigType = dict(
+                 init_cfg: OptMultiConfig = dict(
                      type='Xavier', layer='Conv2d', distribution='uniform')):
         super().__init__(init_cfg=init_cfg)
         assert (num_scales == len(in_channels) == len(out_channels))
@@ -157,19 +159,20 @@ class RetinaSSH(BaseModule):
     Args:
         in_channels (int): Number of input channels used at each scale.
         out_channels (int): Number of output channels used at each scale.
-        conv_cfg (dict, optional): Config dict for convolution layer.
-            Defaults to None.
-        norm_cfg (dict): Config dict for normalization layer.
-            Defaults to dict(type='BN').
-        init_cfg (dict or list[dict], optional): Initialization config dict.
+        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
+            convolution layer. Defaults to None.
+        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
+            layer. Defaults to dict(type='BN').
+        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
+            list[dict], optional): Initialization config dict.
     """
 
     def __init__(self,
                  in_channels: int,
                  out_channels: int,
-                 conv_cfg=None,
+                 conv_cfg: OptConfigType = None,
                  norm_cfg: ConfigType = dict(type='BN'),
-                 init_cfg: ConfigType = dict(
+                 init_cfg: OptMultiConfig = dict(
                      type='Xavier', layer='Conv2d', distribution='uniform')):
         super().__init__(init_cfg=init_cfg)
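
For readers skimming the diff, the substance of the change is the switch from plain dict annotations to the config typing aliases exported by mmdet.utils: conv_cfg is now annotated as OptConfigType (a config or None) and init_cfg as OptMultiConfig (a single config, a list of configs, or None). The snippet below is a minimal sketch of how such aliases are conventionally defined; it assumes the usual Optional/Union pattern and is not a verbatim copy of the mmdet.utils source.

# Sketch only: the shape of the aliases imported at the top of ssh.py,
# assuming the common Optional/Union pattern used for config typing.
from typing import List, Optional, Union

from mmengine.config import ConfigDict

ConfigType = Union[ConfigDict, dict]               # a single config
OptConfigType = Optional[ConfigType]               # a config or None (conv_cfg)
MultiConfig = Union[ConfigType, List[ConfigType]]  # one config or a list of them
OptMultiConfig = Optional[MultiConfig]             # init_cfg: config, list, or None

Under these annotations, a hypothetical call such as SSH(num_scales=2, in_channels=[256, 512], out_channels=[256, 256]) stays consistent with the asserts shown above, and passing init_cfg=None or a list of init dicts remains valid because OptMultiConfig is both optional and multi-valued.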
