diff --git a/mmcv/cnn/bricks/transformer.py b/mmcv/cnn/bricks/transformer.py index 84956ef508..5fc9c4ea0b 100644 --- a/mmcv/cnn/bricks/transformer.py +++ b/mmcv/cnn/bricks/transformer.py @@ -588,24 +588,21 @@ def __init__(self, ffn_drop=0., dropout_layer=None, add_identity=True, - init_cfg=None, - **kwargs): + init_cfg=None): super().__init__(init_cfg) assert num_fcs >= 2, 'num_fcs should be no less ' \ f'than 2. got {num_fcs}.' self.embed_dims = embed_dims self.feedforward_channels = feedforward_channels self.num_fcs = num_fcs - self.act_cfg = act_cfg - self.activate = build_activation_layer(act_cfg) layers = [] in_channels = embed_dims for _ in range(num_fcs - 1): layers.append( Sequential( - Linear(in_channels, feedforward_channels), self.activate, - nn.Dropout(ffn_drop))) + Linear(in_channels, feedforward_channels), + build_activation_layer(act_cfg), nn.Dropout(ffn_drop))) in_channels = feedforward_channels layers.append(Linear(feedforward_channels, embed_dims)) layers.append(nn.Dropout(ffn_drop))