[Enhance] Enhance feature extraction function. #593

Merged 12 commits on Dec 17, 2021
2 changes: 2 additions & 0 deletions configs/_base_/models/mobilenet_v3_large_imagenet.py
@@ -11,4 +11,6 @@
         dropout_rate=0.2,
         act_cfg=dict(type='HSwish'),
         loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        init_cfg=dict(
+            type='Normal', layer='Linear', mean=0., std=0.01, bias=0.),
         topk=(1, 5)))
2 changes: 2 additions & 0 deletions configs/_base_/models/mobilenet_v3_small_imagenet.py
@@ -11,4 +11,6 @@
         dropout_rate=0.2,
         act_cfg=dict(type='HSwish'),
         loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        init_cfg=dict(
+            type='Normal', layer='Linear', mean=0., std=0.01, bias=0.),
         topk=(1, 5)))
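The `init_cfg` added to both MobileNetV3 configs moves the head's Linear-layer initialization into the config, pairing with the `stacked_head.py` change below that drops the hard-coded `init_cfg` from `_init_layers`. A minimal sketch of what this config entry does, using mmcv's `initialize` helper on a standalone layer (the channel sizes here are made up):

```python
import torch.nn as nn
from mmcv.cnn import initialize

# Every module matched by layer='Linear' gets Normal(mean=0, std=0.01)
# weights and a zero bias, as declared in the init_cfg above.
fc = nn.Linear(576, 1000)
initialize(fc, dict(type='Normal', layer='Linear', mean=0., std=0.01, bias=0.))
```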
2 changes: 1 addition & 1 deletion configs/mobilenet_v3/mobilenet-v3-small_8xb32_in1k.py
@@ -17,7 +17,7 @@
 # - modify: RandomErasing use RE-M instead of RE-0

 _base_ = [
-    '../_base_/models/mobilenet-v3-small_8xb32_in1k.py',
+    '../_base_/models/mobilenet_v3_small_imagenet.py',
     '../_base_/datasets/imagenet_bs32_pil_resize.py',
     '../_base_/default_runtime.py'
 ]
7 changes: 4 additions & 3 deletions mmcls/models/classifiers/base.py
@@ -35,13 +35,14 @@ def with_head(self):
         return hasattr(self, 'head') and self.head is not None

     @abstractmethod
-    def extract_feat(self, imgs):
+    def extract_feat(self, imgs, stage=None):
         pass

-    def extract_feats(self, imgs):
+    def extract_feats(self, imgs, stage=None):
         assert isinstance(imgs, list)
+        kwargs = {} if stage is None else {'stage': stage}
         for img in imgs:
-            yield self.extract_feat(img)
+            yield self.extract_feat(img, **kwargs)

     @abstractmethod
     def forward_train(self, imgs, **kwargs):
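Since `stage` is only forwarded when it is not `None`, subclasses whose `extract_feat` does not accept the argument keep working. A rough usage sketch, assuming `model` is any built classifier:

```python
import torch

imgs = [torch.rand(1, 3, 224, 224), torch.rand(1, 3, 224, 224)]
# extract_feats is a generator: one feature per input batch, each
# computed by extract_feat(img, stage='backbone').
feats = list(model.extract_feats(imgs, stage='backbone'))
```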
18 changes: 15 additions & 3 deletions mmcls/models/classifiers/image.py
@@ -70,8 +70,12 @@ def __init__(self,
                 cfg['prob'] = cutmix_prob
             self.augments = Augments(cfg)

-    def extract_feat(self, img):
+    def extract_feat(self, img, stage='neck'):
         """Directly extract features from the backbone + neck."""
+        assert stage in ['backbone', 'neck', 'pre_logits'], \
+            (f'Invalid output stage "{stage}", please choose from "backbone", '
+             '"neck" and "pre_logits"')
+
         x = self.backbone(img)
         if self.return_tuple:
             if not isinstance(x, tuple):
@@ -83,8 +87,16 @@ def extract_feat(self, img):
         else:
             if isinstance(x, tuple):
                 x = x[-1]
+        if stage == 'backbone':
+            return x
+
         if self.with_neck:
             x = self.neck(x)
+        if stage == 'neck':
+            return x
+
+        if self.with_head and hasattr(self.head, 'pre_logits'):
+            x = self.head.pre_logits(x)
         return x

     def forward_train(self, img, gt_label, **kwargs):
@@ -122,12 +134,12 @@ def forward_train(self, img, gt_label, **kwargs):

         return losses

-    def simple_test(self, img, img_metas):
+    def simple_test(self, img, img_metas=None, **kwargs):
         """Test without augmentation."""
         x = self.extract_feat(img)

         try:
-            res = self.head.simple_test(x)
+            res = self.head.simple_test(x, **kwargs)
         except TypeError as e:
             if 'not tuple' in str(e) and self.return_tuple:
                 return TypeError(
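With the new `stage` argument, callers can tap an `ImageClassifier` at three depths without running the final classification layer. A sketch using the high-level API (the config path is illustrative, and no checkpoint is loaded here):

```python
import torch
from mmcls.apis import init_model

model = init_model('configs/resnet/resnet50_8xb32_in1k.py', device='cpu')
img = torch.rand(1, 3, 224, 224)

feat_b = model.extract_feat(img, stage='backbone')    # raw backbone output
feat_n = model.extract_feat(img, stage='neck')        # after the neck (default)
feat_p = model.extract_feat(img, stage='pre_logits')  # head features before the final fc
```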
27 changes: 24 additions & 3 deletions mmcls/models/heads/cls_head.py
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
 import torch
 import torch.nn.functional as F

@@ -62,14 +64,33 @@ def forward_train(self, cls_score, gt_label, **kwargs):
         losses = self.loss(cls_score, gt_label, **kwargs)
         return losses

-    def simple_test(self, cls_score):
+    def pre_logits(self, x):
+        if isinstance(x, tuple):
+            x = x[-1]
+
+        warnings.warn(
+            'The input of ClsHead should be already logits. '
+            'Please modify the backbone if you want to get pre-logits feature.'
+        )
+        return x
+
+    def simple_test(self, cls_score, softmax=True, post_process=True):
         """Test without augmentation."""
         if isinstance(cls_score, tuple):
             cls_score = cls_score[-1]
         if isinstance(cls_score, list):
             cls_score = sum(cls_score) / float(len(cls_score))
-        pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
-        return self.post_process(pred)
+
+        if softmax:
+            pred = (
+                F.softmax(cls_score, dim=1) if cls_score is not None else None)
+        else:
+            pred = cls_score
+
+        if post_process:
+            return self.post_process(pred)
+        else:
+            return pred

     def post_process(self, pred):
         on_trace = is_tracing()
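The two new keyword arguments let callers recover raw logits instead of softmaxed, numpy-converted predictions. A standalone sketch with a made-up score tensor:

```python
import torch
from mmcls.models import ClsHead

head = ClsHead()
cls_score = torch.randn(4, 10)  # pretend output of a classifier

pred = head.simple_test(cls_score)             # softmax + post_process: list of numpy arrays
logits = head.simple_test(cls_score, softmax=False,
                          post_process=False)  # the input tensor, unmodified
```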
33 changes: 21 additions & 12 deletions mmcls/models/heads/conformer_head.py
@@ -16,7 +16,7 @@ class ConformerHead(ClsHead):
             category.
         in_channels (int): Number of channels in the input feature map.
         init_cfg (dict | optional): The extra init config of layers.
-            Defaults to use dict(type='Normal', layer='Linear', std=0.01).
+            Defaults to use ``dict(type='Normal', layer='Linear', std=0.01)``.
     """

     def __init__(
@@ -55,25 +55,34 @@ def init_weights(self):
         else:
             self.apply(self._init_weights)

-    def simple_test(self, x):
-        """Test without augmentation."""
+    def pre_logits(self, x):
         if isinstance(x, tuple):
             x = x[-1]
-        assert isinstance(x,
-                          list)  # There are two outputs in the Conformer model
+        return x
+
+    def simple_test(self, x, softmax=True, post_process=True):
+        """Test without augmentation."""
+        x = self.pre_logits(x)
+        # There are two outputs in the Conformer model
+        assert isinstance(x, list)

         conv_cls_score = self.conv_cls_head(x[0])
         tran_cls_score = self.trans_cls_head(x[1])

-        cls_score = conv_cls_score + tran_cls_score
-
-        pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
-
-        return self.post_process(pred)
+        if softmax:
+            cls_score = conv_cls_score + tran_cls_score
+            pred = (
+                F.softmax(cls_score, dim=1) if cls_score is not None else None)
+            if post_process:
+                pred = self.post_process(pred)
+        else:
+            pred = [conv_cls_score, tran_cls_score]
+            if post_process:
+                pred = list(map(self.post_process, pred))
+        return pred

     def forward_train(self, x, gt_label):
-        if isinstance(x, tuple):
-            x = x[-1]
+        x = self.pre_logits(x)
         assert isinstance(x, list) and len(x) == 2, \
             'There should be two outputs in the Conformer model'

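Note the asymmetry in `ConformerHead.simple_test`: with `softmax=True` the conv and transformer scores are summed before the softmax, while `softmax=False` returns the two branch scores as a list. A rough sketch, assuming `head` is a built `ConformerHead` and `feats` its two-branch input:

```python
# Summed probabilities over both branches (the default path).
probs = head.simple_test(feats)

# Per-branch raw scores, left unprocessed.
conv_score, trans_score = head.simple_test(
    feats, softmax=False, post_process=False)
```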
23 changes: 17 additions & 6 deletions mmcls/models/heads/linear_head.py
@@ -35,20 +35,31 @@ def __init__(self,

         self.fc = nn.Linear(self.in_channels, self.num_classes)

-    def simple_test(self, x):
-        """Test without augmentation."""
+    def pre_logits(self, x):
         if isinstance(x, tuple):
             x = x[-1]
+        return x
+
+    def simple_test(self, x, softmax=True, post_process=True):
+        """Test without augmentation."""
+        x = self.pre_logits(x)
         cls_score = self.fc(x)
         if isinstance(cls_score, list):
             cls_score = sum(cls_score) / float(len(cls_score))
-        pred = F.softmax(cls_score, dim=1) if cls_score is not None else None

-        return self.post_process(pred)
+        if softmax:
+            pred = (
+                F.softmax(cls_score, dim=1) if cls_score is not None else None)
+        else:
+            pred = cls_score
+
+        if post_process:
+            return self.post_process(pred)
+        else:
+            return pred

     def forward_train(self, x, gt_label, **kwargs):
-        if isinstance(x, tuple):
-            x = x[-1]
+        x = self.pre_logits(x)
         cls_score = self.fc(x)
         losses = self.loss(cls_score, gt_label, **kwargs)
         return losses
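`LinearClsHead.pre_logits` is just the tuple-unwrapping step, so `extract_feat(..., stage='pre_logits')` on such a model returns the (last) neck output unchanged. The head itself can be exercised standalone (dimensions are made up):

```python
import torch
from mmcls.models import LinearClsHead

head = LinearClsHead(num_classes=5, in_channels=16)
feat = torch.rand(2, 16)

pred = head.simple_test(feat)                  # list of per-sample numpy arrays
logits = head.simple_test(feat, softmax=False,
                          post_process=False)  # raw (2, 5) logits tensor
```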
26 changes: 22 additions & 4 deletions mmcls/models/heads/multi_label_head.py
@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
-import torch.nn.functional as F

 from ..builder import HEADS, build_loss
 from ..utils import is_tracing
@@ -47,14 +46,33 @@ def forward_train(self, cls_score, gt_label, **kwargs):
         losses = self.loss(cls_score, gt_label, **kwargs)
         return losses

-    def simple_test(self, x):
+    def pre_logits(self, x):
         if isinstance(x, tuple):
             x = x[-1]
+
+        from mmcls.utils import get_root_logger
+        logger = get_root_logger()
+        logger.warning(
+            'The input of MultiLabelClsHead should be already logits. '
+            'Please modify the backbone if you want to get pre-logits feature.'
+        )
+        return x
+
+    def simple_test(self, x, sigmoid=True, post_process=True):
+        if isinstance(x, tuple):
+            x = x[-1]
         if isinstance(x, list):
             x = sum(x) / float(len(x))
-        pred = F.sigmoid(x) if x is not None else None

-        return self.post_process(pred)
+        if sigmoid:
+            pred = torch.sigmoid(x) if x is not None else None
+        else:
+            pred = x
+
+        if post_process:
+            return self.post_process(pred)
+        else:
+            return pred

     def post_process(self, pred):
         on_trace = is_tracing()
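Besides the new flags, this hunk replaces the long-deprecated `F.sigmoid` with `torch.sigmoid` (hence the import change above); the same swap appears in `multi_label_linear_head.py` below. A standalone sketch with a made-up score tensor:

```python
import torch
from mmcls.models import MultiLabelClsHead

head = MultiLabelClsHead()
cls_score = torch.randn(4, 10)

probs = head.simple_test(cls_score)            # sigmoid + post_process
logits = head.simple_test(cls_score, sigmoid=False,
                          post_process=False)  # raw multi-label logits
```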
24 changes: 17 additions & 7 deletions mmcls/models/heads/multi_label_linear_head.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from ..builder import HEADS
 from .multi_label_head import MultiLabelClsHead
@@ -39,21 +39,31 @@ def __init__(self,

         self.fc = nn.Linear(self.in_channels, self.num_classes)

-    def forward_train(self, x, gt_label, **kwargs):
+    def pre_logits(self, x):
         if isinstance(x, tuple):
             x = x[-1]
+        return x
+
+    def forward_train(self, x, gt_label, **kwargs):
+        x = self.pre_logits(x)
         gt_label = gt_label.type_as(x)
         cls_score = self.fc(x)
         losses = self.loss(cls_score, gt_label, **kwargs)
         return losses

-    def simple_test(self, x):
+    def simple_test(self, x, sigmoid=True, post_process=True):
         """Test without augmentation."""
-        if isinstance(x, tuple):
-            x = x[-1]
+        x = self.pre_logits(x)
         cls_score = self.fc(x)
         if isinstance(cls_score, list):
             cls_score = sum(cls_score) / float(len(cls_score))
-        pred = F.sigmoid(cls_score) if cls_score is not None else None

-        return self.post_process(pred)
+        if sigmoid:
+            pred = torch.sigmoid(cls_score) if cls_score is not None else None
+        else:
+            pred = cls_score
+
+        if post_process:
+            return self.post_process(pred)
+        else:
+            return pred
41 changes: 26 additions & 15 deletions mmcls/models/heads/stacked_head.py
@@ -89,9 +89,7 @@ def __init__(self,
         self._init_layers()

     def _init_layers(self):
-        self.layers = ModuleList(
-            init_cfg=dict(
-                type='Normal', layer='Linear', mean=0., std=0.01, bias=0.))
+        self.layers = ModuleList()
         in_channels = self.in_channels
         for hidden_channels in self.mid_channels:
             self.layers.append(
@@ -114,24 +112,37 @@ def _init_layers(self):
     def init_weights(self):
         self.layers.init_weights()

-    def simple_test(self, x):
-        """Test without augmentation."""
+    def pre_logits(self, x):
         if isinstance(x, tuple):
             x = x[-1]
-        cls_score = x
-        for layer in self.layers:
-            cls_score = layer(cls_score)
+        for layer in self.layers[:-1]:
+            x = layer(x)
+        return x
+
+    @property
+    def fc(self):
+        return self.layers[-1]
+
+    def simple_test(self, x, softmax=True, post_process=True):
+        """Test without augmentation."""
+        x = self.pre_logits(x)
+        cls_score = self.fc(x)
         if isinstance(cls_score, list):
             cls_score = sum(cls_score) / float(len(cls_score))
-        pred = F.softmax(cls_score, dim=1) if cls_score is not None else None

-        return self.post_process(pred)
+        if softmax:
+            pred = (
+                F.softmax(cls_score, dim=1) if cls_score is not None else None)
+        else:
+            pred = cls_score
+
+        if post_process:
+            return self.post_process(pred)
+        else:
+            return pred

     def forward_train(self, x, gt_label, **kwargs):
-        if isinstance(x, tuple):
-            x = x[-1]
-        cls_score = x
-        for layer in self.layers:
-            cls_score = layer(cls_score)
+        x = self.pre_logits(x)
+        cls_score = self.fc(x)
         losses = self.loss(cls_score, gt_label, **kwargs)
         return losses
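Splitting the stack into `pre_logits` (every layer but the last) and an `fc` property (the final layer) gives `StackedLinearClsHead` the same two-phase interface as the other heads, which is what makes `stage='pre_logits'` meaningful for it. A standalone sketch (dimensions are made up):

```python
import torch
from mmcls.models import StackedLinearClsHead

head = StackedLinearClsHead(num_classes=10, in_channels=64, mid_channels=[32])
x = torch.rand(4, 64)

feat = head.pre_logits(x)  # (4, 32): all layers except the final linear block
logits = head.fc(feat)     # (4, 10): the last block, exposed as a property
```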