Skip to content

Commit

Permalink
[Feature] Support CSRA head. (#881)
Browse files Browse the repository at this point in the history
* Support CSRA head.

* Add CSRA config.

* Improve training scheduler and Update cfg, ckpt, log

* Update metafile

* Rename config files and checkpoints

Co-authored-by: Ezra-Yu <1105212286@qq.com>
Co-authored-by: mzr1996 <mzr1996@163.com>
  • Loading branch information
3 people authored Aug 4, 2022
1 parent b5bb86a commit 1a3d51a
Show file tree
Hide file tree
Showing 8 changed files with 318 additions and 6 deletions.
36 changes: 36 additions & 0 deletions configs/csra/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# CSRA

> [Residual Attention: A Simple but Effective Method for Multi-Label Recognition](https://arxiv.org/abs/2108.02456)
<!-- [ALGORITHM] -->

## Abstract

Multi-label image recognition is a challenging computer vision task of practical use. Progresses in this area, however, are often characterized by complicated methods, heavy computations, and lack of intuitive explanations. To effectively capture different spatial regions occupied by objects from different categories, we propose an embarrassingly simple module, named class-specific residual attention (CSRA). CSRA generates class-specific features for every category by proposing a simple spatial attention score, and then combines it with the class-agnostic average pooling feature. CSRA achieves state-of-the-art results on multilabel recognition, and at the same time is much simpler than them. Furthermore, with only 4 lines of code, CSRA also leads to consistent improvement across many diverse pretrained models and datasets without any extra training. CSRA is both easy to implement and light in computations, which also enjoys intuitive explanations and visualizations.

<div align=center>
<img src="https://user-images.githubusercontent.com/84259897/176982245-3ffcff56-a4ea-4474-9967-bc2b612bbaa3.png" width="80%"/>
</div>

## Results and models

### VOC2007

| Model | Pretrain | Params(M) | Flops(G) | mAP | OF1 (%) | CF1 (%) | Config | Download |
| :------------: | :------------------------------------------------: | :-------: | :------: | :---: | :-----: | :-----: | :-----------------------------------------------: | :-------------------------------------------------: |
| Resnet101-CSRA | [ImageNet-1k](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth) | 23.55 | 4.12 | 94.98 | 90.80 | 89.16 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/csra/resnet101-csra_1xb16_voc07-448px.py) | [model](https://download.openmmlab.com/mmclassification/v0/csra/resnet101-csra_1xb16_voc07-448px_20220722-29efb40a.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/csra/resnet101-csra_1xb16_voc07-448px_20220722-29efb40a.log.json) |

## Citation

```bibtex
@misc{https://doi.org/10.48550/arxiv.2108.02456,
doi = {10.48550/ARXIV.2108.02456},
url = {https://arxiv.org/abs/2108.02456},
author = {Zhu, Ke and Wu, Jianxin},
keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {Residual Attention: A Simple but Effective Method for Multi-Label Recognition},
publisher = {arXiv},
year = {2021},
copyright = {arXiv.org perpetual, non-exclusive license}
}
```
29 changes: 29 additions & 0 deletions configs/csra/metafile.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
Collections:
- Name: CSRA
Metadata:
Training Data: PASCAL VOC 2007
Architecture:
- Class-specific Residual Attention
Paper:
URL: https://arxiv.org/abs/1911.11929
Title: 'Residual Attention: A Simple but Effective Method for Multi-Label Recognition'
README: configs/csra/README.md
Code:
Version: v0.24.0
URL: https://github.com/open-mmlab/mmclassification/blob/v0.24.0/mmcls/models/heads/multi_label_csra_head.py

Models:
- Name: resnet101-csra_1xb16_voc07-448px
Metadata:
FLOPs: 4120000000
Parameters: 23550000
In Collections: CSRA
Results:
- Dataset: PASCAL VOC 2007
Metrics:
mAP: 94.98
OF1: 90.80
CF1: 89.16
Task: Multi-Label Classification
Weights: https://download.openmmlab.com/mmclassification/v0/csra/resnet101-csra_1xb16_voc07-448px_20220722-29efb40a.pth
Config: configs/csra/resnet101-csra_1xb16_voc07-448px.py
75 changes: 75 additions & 0 deletions configs/csra/resnet101-csra_1xb16_voc07-448px.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
_base_ = ['../_base_/datasets/voc_bs16.py', '../_base_/default_runtime.py']

# Pre-trained Checkpoint Path
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth' # noqa
# If you want to use the pre-trained weight of ResNet101-CutMix from
# the originary repo(https://github.com/Kevinz-code/CSRA). Script of
# 'tools/convert_models/torchvision_to_mmcls.py' can help you convert weight
# into mmcls format. The mAP result would hit 95.5 by using the weight.
# checkpoint = 'PATH/TO/PRE-TRAINED_WEIGHT'

# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(3, ),
style='pytorch',
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint, prefix='backbone')),
neck=None,
head=dict(
type='CSRAClsHead',
num_classes=20,
in_channels=2048,
num_heads=1,
lam=0.1,
loss=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))

# dataset setting
img_norm_cfg = dict(mean=[0, 0, 0], std=[255, 255, 255], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', size=448, scale=(0.7, 1.0)),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', size=448),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data = dict(
# map the difficult examples as negative ones(0)
train=dict(pipeline=train_pipeline, difficult_as_postive=False),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))

# optimizer
# the lr of classifier.head is 10 * base_lr, which help convergence.
optimizer = dict(
type='SGD',
lr=0.0002,
momentum=0.9,
weight_decay=0.0001,
paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10)}))

optimizer_config = dict(grad_clip=None)

# learning policy
lr_config = dict(
policy='step',
step=6,
gamma=0.1,
warmup='linear',
warmup_iters=1,
warmup_ratio=1e-7,
warmup_by_epoch=True)
runner = dict(type='EpochBasedRunner', max_epochs=20)
31 changes: 28 additions & 3 deletions mmcls/datasets/voc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,29 @@

@DATASETS.register_module()
class VOC(MultiLabelDataset):
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset."""
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset.
Args:
data_prefix (str): the prefix of data path
pipeline (list): a list of dict, where each element represents
a operation defined in `mmcls.datasets.pipelines`
ann_file (str | None): the annotation file. When ann_file is str,
the subclass is expected to read from the ann_file. When ann_file
is None, the subclass is expected to read according to data_prefix
difficult_as_postive (Optional[bool]): Whether to map the difficult
labels as positive. If it set to True, map difficult examples to
positive ones(1), If it set to False, map difficult examples to
negative ones(0). Defaults to None, the difficult labels will be
set to '-1'.
"""

CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
'tvmonitor')

def __init__(self, **kwargs):
def __init__(self, difficult_as_postive=None, **kwargs):
self.difficult_as_postive = difficult_as_postive
super(VOC, self).__init__(**kwargs)
if 'VOC2007' in self.data_prefix:
self.year = 2007
Expand Down Expand Up @@ -55,9 +70,19 @@ def load_annotations(self):
labels.append(label)

gt_label = np.zeros(len(self.CLASSES))
# set difficult example first, then set postivate examples.
# The order cannot be swapped for the case where multiple objects
# of the same kind exist and some are difficult.
gt_label[labels_difficult] = -1
if self.difficult_as_postive is None:
# map difficult examples to -1,
# it may be used in evaluation to ignore difficult targets.
gt_label[labels_difficult] = -1
elif self.difficult_as_postive:
# map difficult examples to positive ones(1).
gt_label[labels_difficult] = 1
else:
# map difficult examples to negative ones(0).
gt_label[labels_difficult] = 0
gt_label[labels] = 1

info = dict(
Expand Down
3 changes: 2 additions & 1 deletion mmcls/models/heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .conformer_head import ConformerHead
from .deit_head import DeiTClsHead
from .linear_head import LinearClsHead
from .multi_label_csra_head import CSRAClsHead
from .multi_label_head import MultiLabelClsHead
from .multi_label_linear_head import MultiLabelLinearClsHead
from .stacked_head import StackedLinearClsHead
Expand All @@ -11,5 +12,5 @@
__all__ = [
'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'DeiTClsHead',
'ConformerHead'
'ConformerHead', 'CSRAClsHead'
]
121 changes: 121 additions & 0 deletions mmcls/models/heads/multi_label_csra_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/Kevinz-code/CSRA
import torch
import torch.nn as nn
from mmcv.runner import BaseModule, ModuleList

from ..builder import HEADS
from .multi_label_head import MultiLabelClsHead


@HEADS.register_module()
class CSRAClsHead(MultiLabelClsHead):
"""Class-specific residual attention classifier head.
Residual Attention: A Simple but Effective Method for Multi-Label
Recognition (ICCV 2021)
Please refer to the `paper <https://arxiv.org/abs/2108.02456>`__ for
details.
Args:
num_classes (int): Number of categories.
in_channels (int): Number of channels in the input feature map.
num_heads (int): Number of residual at tensor heads.
loss (dict): Config of classification loss.
lam (float): Lambda that combines global average and max pooling
scores.
init_cfg (dict | optional): The extra init config of layers.
Defaults to use dict(type='Normal', layer='Linear', std=0.01).
"""
temperature_settings = { # softmax temperature settings
1: [1],
2: [1, 99],
4: [1, 2, 4, 99],
6: [1, 2, 3, 4, 5, 99],
8: [1, 2, 3, 4, 5, 6, 7, 99]
}

def __init__(self,
num_classes,
in_channels,
num_heads,
lam,
loss=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=1.0),
init_cfg=dict(type='Normal', layer='Linear', std=0.01),
*args,
**kwargs):
assert num_heads in self.temperature_settings.keys(
), 'The num of heads is not in temperature setting.'
assert lam > 0, 'Lambda should be between 0 and 1.'
super(CSRAClsHead, self).__init__(
init_cfg=init_cfg, loss=loss, *args, **kwargs)
self.temp_list = self.temperature_settings[num_heads]
self.csra_heads = ModuleList([
CSRAModule(num_classes, in_channels, self.temp_list[i], lam)
for i in range(num_heads)
])

def pre_logits(self, x):
if isinstance(x, tuple):
x = x[-1]
return x

def simple_test(self, x, post_process=True, **kwargs):
logit = 0.
x = self.pre_logits(x)
for head in self.csra_heads:
logit += head(x)
if post_process:
return self.post_process(logit)
else:
return logit

def forward_train(self, x, gt_label, **kwargs):
logit = 0.
x = self.pre_logits(x)
for head in self.csra_heads:
logit += head(x)
gt_label = gt_label.type_as(logit)
_gt_label = torch.abs(gt_label)
losses = self.loss(logit, _gt_label, **kwargs)
return losses


class CSRAModule(BaseModule):
"""Basic module of CSRA with different temperature.
Args:
num_classes (int): Number of categories.
in_channels (int): Number of channels in the input feature map.
T (int): Temperature setting.
lam (float): Lambda that combines global average and max pooling
scores.
init_cfg (dict | optional): The extra init config of layers.
Defaults to use dict(type='Normal', layer='Linear', std=0.01).
"""

def __init__(self, num_classes, in_channels, T, lam, init_cfg=None):

super(CSRAModule, self).__init__(init_cfg=init_cfg)
self.T = T # temperature
self.lam = lam # Lambda
self.head = nn.Conv2d(in_channels, num_classes, 1, bias=False)
self.softmax = nn.Softmax(dim=2)

def forward(self, x):
score = self.head(x) / torch.norm(
self.head.weight, dim=1, keepdim=True).transpose(0, 1)
score = score.flatten(2)
base_logit = torch.mean(score, dim=2)

if self.T == 99: # max-pooling
att_logit = torch.max(score, dim=2)[0]
else:
score_soft = self.softmax(score * self.T)
att_logit = torch.sum(score * score_soft, dim=2)

return base_logit + self.lam * att_logit
1 change: 1 addition & 0 deletions model-index.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ Import:
- configs/convmixer/metafile.yml
- configs/densenet/metafile.yml
- configs/poolformer/metafile.yml
- configs/csra/metafile.yml
- configs/mvit/metafile.yml
28 changes: 26 additions & 2 deletions tests/test_models/test_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import pytest
import torch

from mmcls.models.heads import (ClsHead, ConformerHead, DeiTClsHead,
LinearClsHead, MultiLabelClsHead,
from mmcls.models.heads import (ClsHead, ConformerHead, CSRAClsHead,
DeiTClsHead, LinearClsHead, MultiLabelClsHead,
MultiLabelLinearClsHead, StackedLinearClsHead,
VisionTransformerClsHead)

Expand Down Expand Up @@ -317,3 +317,27 @@ def test_deit_head():
# test assertion
with pytest.raises(ValueError):
DeiTClsHead(-1, 100)


@pytest.mark.parametrize(
'feat', [torch.rand(4, 20, 20, 30), (torch.rand(4, 20, 20, 30), )])
def test_csra_head(feat):
head = CSRAClsHead(num_classes=10, in_channels=20, num_heads=1, lam=0.1)
fake_gt_label = torch.randint(0, 2, (4, 10))

losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0

# test simple_test with post_process
pred = head.simple_test(feat)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(feat)
assert pred.shape == (4, 10)

# test pre_logits
features = head.pre_logits(feat)
if isinstance(feat, tuple):
torch.testing.assert_allclose(features, feat[0])
else:
torch.testing.assert_allclose(features, feat)

0 comments on commit 1a3d51a

Please sign in to comment.