Skip to content

Commit

Permalink
[Enhance] Improve downstream repositories compatibility (#421)
Browse files Browse the repository at this point in the history
* Defaults to return tuple in all backbones.

* Compat downstream of swin transformer.

* Support tuple input for multi label head and stacked head.

* Fix backbone unit tests for tuple output.

* Add downstream inference unit tests for mmdet.

* Update gitignore

* Add unit tests for `return_tuple` option

* Add unit tests for head input tuple.

* Add warning in `simple_test`

* Add TIMMBackbone return tuple.

* Modify timm backbone unit test.
  • Loading branch information
mzr1996 authored Sep 8, 2021
1 parent 5cfaed6 commit a8f4f82
Show file tree
Hide file tree
Showing 41 changed files with 384 additions and 146 deletions.
6 changes: 3 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,14 @@ venv.bak/
.mypy_cache/

# custom
data
/data
.vscode
.idea
*.pkl
*.pkl.json
*.log.json
work_dirs/
mmcls/.mim
/work_dirs
/mmcls/.mim

# Pytorch
*.pth
2 changes: 1 addition & 1 deletion configs/_base_/models/swin_transformer/base_224.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
type='ImageClassifier',
backbone=dict(
type='SwinTransformer', arch='base', img_size=224, drop_path_rate=0.5),
neck=dict(type='GlobalAveragePooling', dim=1),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
Expand Down
2 changes: 1 addition & 1 deletion configs/_base_/models/swin_transformer/base_384.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
arch='base',
img_size=384,
stage_cfgs=dict(block_cfgs=dict(window_size=12))),
neck=dict(type='GlobalAveragePooling', dim=1),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
Expand Down
2 changes: 1 addition & 1 deletion configs/_base_/models/swin_transformer/large_224.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
model = dict(
type='ImageClassifier',
backbone=dict(type='SwinTransformer', arch='large', img_size=224),
neck=dict(type='GlobalAveragePooling', dim=1),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
Expand Down
2 changes: 1 addition & 1 deletion configs/_base_/models/swin_transformer/large_384.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
arch='large',
img_size=384,
stage_cfgs=dict(block_cfgs=dict(window_size=12))),
neck=dict(type='GlobalAveragePooling', dim=1),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
Expand Down
2 changes: 1 addition & 1 deletion configs/_base_/models/swin_transformer/small_224.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
backbone=dict(
type='SwinTransformer', arch='small', img_size=224,
drop_path_rate=0.3),
neck=dict(type='GlobalAveragePooling', dim=1),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
Expand Down
2 changes: 1 addition & 1 deletion configs/_base_/models/swin_transformer/tiny_224.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
type='ImageClassifier',
backbone=dict(
type='SwinTransformer', arch='tiny', img_size=224, drop_path_rate=0.2),
neck=dict(type='GlobalAveragePooling', dim=1),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
Expand Down
2 changes: 1 addition & 1 deletion mmcls/models/backbones/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ def forward(self, x):
x = x.view(x.size(0), 256 * 6 * 6)
x = self.classifier(x)

return x
return (x, )
2 changes: 1 addition & 1 deletion mmcls/models/backbones/lenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ def forward(self, x):
if self.num_classes > 0:
x = self.classifier(x.squeeze())

return x
return (x, )
5 changes: 1 addition & 4 deletions mmcls/models/backbones/mobilenet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,7 @@ def forward(self, x):
if i in self.out_indices:
outs.append(x)

if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
return tuple(outs)

def _freeze_stages(self):
if self.frozen_stages >= 0:
Expand Down
5 changes: 1 addition & 4 deletions mmcls/models/backbones/mobilenet_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,7 @@ def forward(self, x):
if i in self.out_indices:
outs.append(x)

if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
return tuple(outs)

def _freeze_stages(self):
for i in range(0, self.frozen_stages + 1):
Expand Down
5 changes: 1 addition & 4 deletions mmcls/models/backbones/regnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,4 @@ def forward(self, x):
if i in self.out_indices:
outs.append(x)

if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
return tuple(outs)
5 changes: 1 addition & 4 deletions mmcls/models/backbones/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,10 +624,7 @@ def forward(self, x):
x = res_layer(x)
if i in self.out_indices:
outs.append(x)
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
return tuple(outs)

def train(self, mode=True):
super(ResNet, self).train(mode)
Expand Down
5 changes: 1 addition & 4 deletions mmcls/models/backbones/resnet_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,4 @@ def forward(self, x):
x = res_layer(x)
if i in self.out_indices:
outs.append(x)
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
return tuple(outs)
5 changes: 1 addition & 4 deletions mmcls/models/backbones/shufflenet_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,7 @@ def forward(self, x):
if i in self.out_indices:
outs.append(x)

if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
return tuple(outs)

def train(self, mode=True):
super(ShuffleNetV1, self).train(mode)
Expand Down
5 changes: 1 addition & 4 deletions mmcls/models/backbones/shufflenet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,7 @@ def forward(self, x):
if i in self.out_indices:
outs.append(x)

if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
return tuple(outs)

def train(self, mode=True):
super(ShuffleNetV2, self).train(mode)
Expand Down
75 changes: 63 additions & 12 deletions mmcls/models/backbones/swin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def __init__(self,
if not isinstance(block_cfgs, Sequence):
block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]

self.embed_dims = embed_dims
self.input_resolution = input_resolution
self.blocks = ModuleList()
for i in range(depth):
_block_cfg = {
Expand Down Expand Up @@ -171,6 +173,20 @@ def forward(self, x):
x = self.downsample(x)
return x

@property
def out_resolution(self):
if self.downsample:
return self.downsample.output_resolution
else:
return self.input_resolution

@property
def out_channels(self):
if self.downsample:
return self.downsample.out_channels
else:
return self.embed_dims


@BACKBONES.register_module()
class SwinTransformer(BaseBackbone):
Expand Down Expand Up @@ -239,12 +255,15 @@ class SwinTransformer(BaseBackbone):
'num_heads': [6, 12, 24, 48]}),
} # yapf: disable

_version = 2

def __init__(self,
arch='T',
img_size=224,
in_channels=3,
drop_rate=0.,
drop_path_rate=0.1,
out_indices=(3, ),
use_abs_pos_embed=False,
auto_pad=False,
norm_cfg=dict(type='LN'),
Expand All @@ -268,6 +287,7 @@ def __init__(self,
self.depths = self.arch_settings['depths']
self.num_heads = self.arch_settings['num_heads']
self.num_layers = len(self.depths)
self.out_indices = out_indices
self.use_abs_pos_embed = use_abs_pos_embed
self.auto_pad = auto_pad

Expand Down Expand Up @@ -321,18 +341,25 @@ def __init__(self,
self.stages.append(stage)

dpr = dpr[depth:]
if downsample:
embed_dims = stage.downsample.out_channels
input_resolution = stage.downsample.output_resolution
embed_dims = stage.out_channels
input_resolution = stage.out_resolution

if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else:
self.norm = None
for i in out_indices:
if norm_cfg is not None:
norm_layer = build_norm_layer(norm_cfg, embed_dims)[1]
else:
norm_layer = nn.Identity()

self.add_module(f'norm{i}', norm_layer)

def init_weights(self):
super(SwinTransformer, self).init_weights()

if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
return

if self.use_abs_pos_embed:
trunc_normal_(self.absolute_pos_embed, std=0.02)

Expand All @@ -342,9 +369,33 @@ def forward(self, x):
x = x + self.absolute_pos_embed
x = self.drop_after_pos(x)

for stage in self.stages:
outs = []
for i, stage in enumerate(self.stages):
x = stage(x)

x = self.norm(x) if self.norm else x

return x.transpose(1, 2)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(x)
out = out.view(-1, *stage.out_resolution,
stage.out_channels).permute(0, 3, 1,
2).contiguous()
outs.append(out)

return tuple(outs)

def _load_from_state_dict(self, state_dict, prefix, local_metadata, *args,
**kwargs):
"""load checkpoints."""
# Names of some parameters in has been changed.
version = local_metadata.get('version', None)
if (version is None
or version < 2) and self.__class__ is SwinTransformer:
final_stage_num = len(self.stages) - 1
state_dict_keys = list(state_dict.keys())
for k in state_dict_keys:
if k.startswith('norm.') or k.startswith('backbone.norm.'):
convert_key = k.replace('norm.', f'norm{final_stage_num}.')
state_dict[convert_key] = state_dict[k]
del state_dict[k]

super()._load_from_state_dict(state_dict, prefix, local_metadata,
*args, **kwargs)
2 changes: 1 addition & 1 deletion mmcls/models/backbones/timm_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ def __init__(

def forward(self, x):
features = self.timm_model.forward_features(x)
return features
return (features, )
6 changes: 2 additions & 4 deletions mmcls/models/backbones/vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,8 @@ def forward(self, x):
x = x.view(x.size(0), -1)
x = self.classifier(x)
outs.append(x)
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)

return tuple(outs)

def _freeze_stages(self):
vgg_layers = getattr(self, self.module_name)
Expand Down
49 changes: 44 additions & 5 deletions mmcls/models/classifiers/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@ def __init__(self,
key, please consider using init_cfg')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

return_tuple = backbone.pop('return_tuple', True)
self.backbone = build_backbone(backbone)
if return_tuple is False:
warnings.warn(
'The `return_tuple` is a temporary arg, we will force to '
'return tuple in the future. Please handle tuple in your '
'custom neck or head.', DeprecationWarning)
self.return_tuple = return_tuple

if neck is not None:
self.neck = build_neck(neck)
Expand Down Expand Up @@ -64,6 +71,17 @@ def __init__(self,
def extract_feat(self, img):
"""Directly extract features from the backbone + neck."""
x = self.backbone(img)
if self.return_tuple:
if not isinstance(x, tuple):
x = (x, )
warnings.simplefilter('once')
warnings.warn(
'We will force all backbones to return a tuple in the '
'future. Please check your backbone and wrap the output '
'as a tuple.', DeprecationWarning)
else:
if isinstance(x, tuple):
x = x[-1]
if self.with_neck:
x = self.neck(x)
return x
Expand All @@ -87,15 +105,36 @@ def forward_train(self, img, gt_label, **kwargs):
x = self.extract_feat(img)

losses = dict()
loss = self.head.forward_train(x, gt_label)
try:
loss = self.head.forward_train(x, gt_label)
except TypeError as e:
if 'not tuple' in str(e) and self.return_tuple:
return TypeError(
'Seems the head cannot handle tuple input. We have '
'changed all backbones\' output to a tuple. Please '
'update your custom head\'s forward function. '
'Temporarily, you can set "return_tuple=False" in '
'your backbone config to disable this feature.')
raise e

losses.update(loss)

return losses

def simple_test(self, img, img_metas):
"""Test without augmentation."""
x = self.extract_feat(img)
x_dims = len(x.shape)
if x_dims == 1:
x.unsqueeze_(0)
return self.head.simple_test(x)

try:
res = self.head.simple_test(x)
except TypeError as e:
if 'not tuple' in str(e) and self.return_tuple:
return TypeError(
'Seems the head cannot handle tuple input. We have '
'changed all backbones\' output to a tuple. Please '
'update your custom head\'s forward function. '
'Temporarily, you can set "return_tuple=False" in '
'your backbone config to disable this feature.')
raise e

return res
4 changes: 4 additions & 0 deletions mmcls/models/heads/cls_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,15 @@ def loss(self, cls_score, gt_label):
return losses

def forward_train(self, cls_score, gt_label):
if isinstance(cls_score, tuple):
cls_score = cls_score[-1]
losses = self.loss(cls_score, gt_label)
return losses

def simple_test(self, cls_score):
"""Test without augmentation."""
if isinstance(cls_score, tuple):
cls_score = cls_score[-1]
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
Expand Down
Loading

0 comments on commit a8f4f82

Please sign in to comment.