* Support CSRA head.
* Add CSRA config.
* Improve training scheduler and Update cfg, ckpt, log
* Update metafile
* Rename config files and checkpoints

Co-authored-by: Ezra-Yu <1105212286@qq.com>
Co-authored-by: mzr1996 <mzr1996@163.com>
1 parent b5bb86a, commit 1a3d51a
Showing 8 changed files with 318 additions and 6 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# CSRA | ||
|
||
> [Residual Attention: A Simple but Effective Method for Multi-Label Recognition](https://arxiv.org/abs/2108.02456) | ||
<!-- [ALGORITHM] --> | ||
|
||
## Abstract | ||
|
||
Multi-label image recognition is a challenging computer vision task of practical use. Progresses in this area, however, are often characterized by complicated methods, heavy computations, and lack of intuitive explanations. To effectively capture different spatial regions occupied by objects from different categories, we propose an embarrassingly simple module, named class-specific residual attention (CSRA). CSRA generates class-specific features for every category by proposing a simple spatial attention score, and then combines it with the class-agnostic average pooling feature. CSRA achieves state-of-the-art results on multilabel recognition, and at the same time is much simpler than them. Furthermore, with only 4 lines of code, CSRA also leads to consistent improvement across many diverse pretrained models and datasets without any extra training. CSRA is both easy to implement and light in computations, which also enjoys intuitive explanations and visualizations. | ||
|
||
<div align=center> | ||
<img src="https://user-images.githubusercontent.com/84259897/176982245-3ffcff56-a4ea-4474-9967-bc2b612bbaa3.png" width="80%"/> | ||
</div> | ||
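
The combination described in the abstract can be written compactly. The notation below is introduced here for illustration and is not quoted from the paper: $s_{i,j}$ is the score of class $i$ at spatial location $j$, $N$ is the number of locations, $T$ is the softmax temperature, and $\lambda$ weights the attention term against the average-pooled term.

$$
\hat{y}_i \;=\; \frac{1}{N}\sum_{j=1}^{N} s_{i,j} \;+\; \lambda \sum_{j=1}^{N} \frac{\exp(T\, s_{i,j})}{\sum_{k=1}^{N}\exp(T\, s_{i,k})}\; s_{i,j}
$$

As $T \to \infty$ the softmax concentrates on the highest-scoring location, so the second term approaches $\lambda \max_j s_{i,j}$, i.e. class-wise max pooling.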
|
||
## Results and models | ||
|
||
### VOC2007 | ||
|
||
| Model | Pretrain | Params(M) | Flops(G) | mAP | OF1 (%) | CF1 (%) | Config | Download | | ||
| :------------: | :------------------------------------------------: | :-------: | :------: | :---: | :-----: | :-----: | :-----------------------------------------------: | :-------------------------------------------------: | | ||
| ResNet101-CSRA | [ImageNet-1k](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth) | 23.55 | 4.12 | 94.98 | 90.80 | 89.16 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/csra/resnet101-csra_1xb16_voc07-448px.py) | [model](https://download.openmmlab.com/mmclassification/v0/csra/resnet101-csra_1xb16_voc07-448px_20220722-29efb40a.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/csra/resnet101-csra_1xb16_voc07-448px_20220722-29efb40a.log.json) | | ||
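
For readers unfamiliar with the OF1 and CF1 columns, the sketch below shows how the two scores are commonly computed in multi-label evaluation. It assumes the usual definitions (CF1 from class-averaged precision and recall, OF1 from counts pooled over all classes) and that scores have already been thresholded to 0/1. The function name `multilabel_f1` and the toy data are hypothetical and not part of mmcls.

```python
def multilabel_f1(pred, target, eps=1e-12):
    """Return (CF1, OF1) in percent for 0/1 prediction and target rows."""
    num_classes = len(target[0])
    tp = [0] * num_classes  # true positives per class
    fp = [0] * num_classes  # false positives per class
    fn = [0] * num_classes  # false negatives per class
    for pred_row, target_row in zip(pred, target):
        for c, (p, t) in enumerate(zip(pred_row, target_row)):
            if p == 1 and t == 1:
                tp[c] += 1
            elif p == 1 and t == 0:
                fp[c] += 1
            elif p == 0 and t == 1:
                fn[c] += 1

    # Class-wise: average per-class precision and recall, then take F1.
    cp = sum(tp[c] / max(tp[c] + fp[c], eps)
             for c in range(num_classes)) / num_classes
    cr = sum(tp[c] / max(tp[c] + fn[c], eps)
             for c in range(num_classes)) / num_classes
    cf1 = 2 * cp * cr / max(cp + cr, eps)

    # Overall: pool the counts over all classes before computing F1.
    op = sum(tp) / max(sum(tp) + sum(fp), eps)
    orec = sum(tp) / max(sum(tp) + sum(fn), eps)
    of1 = 2 * op * orec / max(op + orec, eps)
    return 100 * cf1, 100 * of1


# Toy example: 4 images, 2 classes (rows are images, columns are classes).
pred = [[1, 0], [1, 1], [0, 1], [0, 1]]
target = [[1, 0], [0, 1], [1, 1], [0, 1]]
cf1, of1 = multilabel_f1(pred, target)
```

The two scores differ when classes are imbalanced: CF1 weights every class equally, while OF1 is dominated by frequent classes.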
|
||
## Citation | ||
|
||
```bibtex | ||
@misc{https://doi.org/10.48550/arxiv.2108.02456, | ||
doi = {10.48550/ARXIV.2108.02456}, | ||
url = {https://arxiv.org/abs/2108.02456}, | ||
author = {Zhu, Ke and Wu, Jianxin}, | ||
keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences}, | ||
title = {Residual Attention: A Simple but Effective Method for Multi-Label Recognition}, | ||
publisher = {arXiv}, | ||
year = {2021}, | ||
copyright = {arXiv.org perpetual, non-exclusive license} | ||
} | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
Collections: | ||
- Name: CSRA | ||
Metadata: | ||
Training Data: PASCAL VOC 2007 | ||
Architecture: | ||
- Class-specific Residual Attention | ||
Paper: | ||
URL: https://arxiv.org/abs/2108.02456 | ||
Title: 'Residual Attention: A Simple but Effective Method for Multi-Label Recognition' | ||
README: configs/csra/README.md | ||
Code: | ||
Version: v0.24.0 | ||
URL: https://github.com/open-mmlab/mmclassification/blob/v0.24.0/mmcls/models/heads/multi_label_csra_head.py | ||
|
||
Models: | ||
- Name: resnet101-csra_1xb16_voc07-448px | ||
Metadata: | ||
FLOPs: 4120000000 | ||
Parameters: 23550000 | ||
In Collections: CSRA | ||
Results: | ||
- Dataset: PASCAL VOC 2007 | ||
Metrics: | ||
mAP: 94.98 | ||
OF1: 90.80 | ||
CF1: 89.16 | ||
Task: Multi-Label Classification | ||
Weights: https://download.openmmlab.com/mmclassification/v0/csra/resnet101-csra_1xb16_voc07-448px_20220722-29efb40a.pth | ||
Config: configs/csra/resnet101-csra_1xb16_voc07-448px.py |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
_base_ = ['../_base_/datasets/voc_bs16.py', '../_base_/default_runtime.py'] | ||
|
||
# Pre-trained Checkpoint Path | ||
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth' # noqa | ||
# To use the pre-trained ResNet101-CutMix weights from the original | ||
# repo (https://github.com/Kevinz-code/CSRA), convert them to mmcls format | ||
# with 'tools/convert_models/torchvision_to_mmcls.py'. With those weights | ||
# the mAP reaches 95.5. | ||
# checkpoint = 'PATH/TO/PRE-TRAINED_WEIGHT' | ||
|
||
# model settings | ||
model = dict( | ||
type='ImageClassifier', | ||
backbone=dict( | ||
type='ResNet', | ||
depth=101, | ||
num_stages=4, | ||
out_indices=(3, ), | ||
style='pytorch', | ||
init_cfg=dict( | ||
type='Pretrained', checkpoint=checkpoint, prefix='backbone')), | ||
neck=None, | ||
head=dict( | ||
type='CSRAClsHead', | ||
num_classes=20, | ||
in_channels=2048, | ||
num_heads=1, | ||
lam=0.1, | ||
loss=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0))) | ||
|
||
# dataset setting | ||
img_norm_cfg = dict(mean=[0, 0, 0], std=[255, 255, 255], to_rgb=True) | ||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='RandomResizedCrop', size=448, scale=(0.7, 1.0)), | ||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict(type='ToTensor', keys=['gt_label']), | ||
dict(type='Collect', keys=['img', 'gt_label']) | ||
] | ||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='Resize', size=448), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict(type='Collect', keys=['img']) | ||
] | ||
data = dict( | ||
# treat the difficult examples as negatives (0) | ||
train=dict(pipeline=train_pipeline, difficult_as_postive=False), | ||
val=dict(pipeline=test_pipeline), | ||
test=dict(pipeline=test_pipeline)) | ||
|
||
# optimizer | ||
# the lr of classifier.head is 10 * base_lr, which helps convergence. | ||
optimizer = dict( | ||
type='SGD', | ||
lr=0.0002, | ||
momentum=0.9, | ||
weight_decay=0.0001, | ||
paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10)})) | ||
|
||
optimizer_config = dict(grad_clip=None) | ||
|
||
# learning policy | ||
lr_config = dict( | ||
policy='step', | ||
step=6, | ||
gamma=0.1, | ||
warmup='linear', | ||
warmup_iters=1, | ||
warmup_ratio=1e-7, | ||
warmup_by_epoch=True) | ||
runner = dict(type='EpochBasedRunner', max_epochs=20) |
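
The optimizer and learning-policy settings above can be traced with a small sketch. It assumes that an integer `step` is treated as a period, so the rate is multiplied by `gamma` every 6 epochs, which is how I understand mmcv's step policy; the one-epoch linear warmup is left out for brevity. The helper `step_lr` and the constants are illustrative names, not mmcv APIs.

```python
def step_lr(base_lr, epoch, step=6, gamma=0.1):
    """Learning rate for `epoch` under periodic step decay (no warmup)."""
    return base_lr * gamma ** (epoch // step)


BASE_LR = 0.0002   # `lr` in the optimizer config
HEAD_LR_MULT = 10  # `lr_mult` for parameters whose name contains 'head'
MAX_EPOCHS = 20    # `max_epochs` in the runner config

# One (epoch, backbone_lr, head_lr) tuple per training epoch.
schedule = [
    (epoch, step_lr(BASE_LR, epoch), step_lr(BASE_LR * HEAD_LR_MULT, epoch))
    for epoch in range(MAX_EPOCHS)
]

for epoch, backbone_lr, head_lr in schedule:
    print(f'epoch {epoch:2d}: backbone lr {backbone_lr:.1e}, '
          f'head lr {head_lr:.1e}')
```

Under these assumptions the rate drops at epochs 6, 12 and 18, and the head stays at ten times the backbone rate throughout.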
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
# Modified from https://github.com/Kevinz-code/CSRA | ||
import torch | ||
import torch.nn as nn | ||
from mmcv.runner import BaseModule, ModuleList | ||
|
||
from ..builder import HEADS | ||
from .multi_label_head import MultiLabelClsHead | ||
|
||
|
||
@HEADS.register_module() | ||
class CSRAClsHead(MultiLabelClsHead): | ||
"""Class-specific residual attention classifier head. | ||
Residual Attention: A Simple but Effective Method for Multi-Label | ||
Recognition (ICCV 2021) | ||
Please refer to the `paper <https://arxiv.org/abs/2108.02456>`__ for | ||
details. | ||
Args: | ||
num_classes (int): Number of categories. | ||
in_channels (int): Number of channels in the input feature map. | ||
num_heads (int): Number of residual attention heads. | ||
loss (dict): Config of classification loss. | ||
lam (float): Lambda that combines global average and max pooling | ||
scores. | ||
init_cfg (dict, optional): The extra init config of layers. | ||
Defaults to use dict(type='Normal', layer='Linear', std=0.01). | ||
""" | ||
temperature_settings = { # softmax temperature settings | ||
1: [1], | ||
2: [1, 99], | ||
4: [1, 2, 4, 99], | ||
6: [1, 2, 3, 4, 5, 99], | ||
8: [1, 2, 3, 4, 5, 6, 7, 99] | ||
} | ||
|
||
def __init__(self, | ||
num_classes, | ||
in_channels, | ||
num_heads, | ||
lam, | ||
loss=dict( | ||
type='CrossEntropyLoss', | ||
use_sigmoid=True, | ||
reduction='mean', | ||
loss_weight=1.0), | ||
init_cfg=dict(type='Normal', layer='Linear', std=0.01), | ||
*args, | ||
**kwargs): | ||
assert num_heads in self.temperature_settings.keys( | ||
), 'The num of heads is not in temperature setting.' | ||
assert lam > 0, 'Lambda should be greater than 0.' | ||
super(CSRAClsHead, self).__init__( | ||
init_cfg=init_cfg, loss=loss, *args, **kwargs) | ||
self.temp_list = self.temperature_settings[num_heads] | ||
self.csra_heads = ModuleList([ | ||
CSRAModule(num_classes, in_channels, self.temp_list[i], lam) | ||
for i in range(num_heads) | ||
]) | ||
|
||
def pre_logits(self, x): | ||
if isinstance(x, tuple): | ||
x = x[-1] | ||
return x | ||
|
||
def simple_test(self, x, post_process=True, **kwargs): | ||
logit = 0. | ||
x = self.pre_logits(x) | ||
for head in self.csra_heads: | ||
logit += head(x) | ||
if post_process: | ||
return self.post_process(logit) | ||
else: | ||
return logit | ||
|
||
def forward_train(self, x, gt_label, **kwargs): | ||
logit = 0. | ||
x = self.pre_logits(x) | ||
for head in self.csra_heads: | ||
logit += head(x) | ||
gt_label = gt_label.type_as(logit) | ||
_gt_label = torch.abs(gt_label) | ||
losses = self.loss(logit, _gt_label, **kwargs) | ||
return losses | ||
|
||
|
||
class CSRAModule(BaseModule): | ||
"""Basic module of CSRA with different temperature. | ||
Args: | ||
num_classes (int): Number of categories. | ||
in_channels (int): Number of channels in the input feature map. | ||
T (int): Softmax temperature. A value of 99 selects max pooling. | ||
lam (float): Lambda that combines global average and max pooling | ||
scores. | ||
init_cfg (dict, optional): The extra init config of layers. | ||
Defaults to use dict(type='Normal', layer='Linear', std=0.01). | ||
""" | ||
|
||
def __init__(self, num_classes, in_channels, T, lam, init_cfg=None): | ||
|
||
super(CSRAModule, self).__init__(init_cfg=init_cfg) | ||
self.T = T # temperature | ||
self.lam = lam # Lambda | ||
self.head = nn.Conv2d(in_channels, num_classes, 1, bias=False) | ||
self.softmax = nn.Softmax(dim=2) | ||
|
||
def forward(self, x): | ||
score = self.head(x) / torch.norm( | ||
self.head.weight, dim=1, keepdim=True).transpose(0, 1) | ||
score = score.flatten(2) | ||
base_logit = torch.mean(score, dim=2) | ||
|
||
if self.T == 99: # max-pooling | ||
att_logit = torch.max(score, dim=2)[0] | ||
else: | ||
score_soft = self.softmax(score * self.T) | ||
att_logit = torch.sum(score * score_soft, dim=2) | ||
|
||
return base_logit + self.lam * att_logit |
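
The forward pass of `CSRAModule` and the head-wise sum in `CSRAClsHead` can be traced without PyTorch. The following is a minimal pure-Python sketch for a single class and a flattened feature map. The function names `normalized_scores`, `csra_logit` and `multi_head_logit` and the toy numbers are illustrative only and are not part of mmcls.

```python
import math

# Sentinel temperature that the head above treats as "use max pooling".
MAX_POOL_T = 99


def normalized_scores(features, class_weight):
    """Project each spatial feature onto the unit-length class weight."""
    norm = math.sqrt(sum(w * w for w in class_weight))
    return [
        sum(w * f for w, f in zip(class_weight, feat)) / norm
        for feat in features
    ]


def csra_logit(scores, lam, temperature):
    """Average-pooled score plus lam times the attention-pooled score."""
    base = sum(scores) / len(scores)
    if temperature == MAX_POOL_T:
        att = max(scores)
    else:
        # Softmax over spatial positions, shifted by the max for stability.
        peak = max(scores)
        exps = [math.exp(temperature * (s - peak)) for s in scores]
        total = sum(exps)
        att = sum(s * e / total for s, e in zip(scores, exps))
    return base + lam * att


def multi_head_logit(scores, lam, temperatures):
    """Sum the single-head logits over all temperatures."""
    return sum(csra_logit(scores, lam, t) for t in temperatures)


# Toy input: 4 spatial positions with 2-channel features, one class.
features = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [0.0, 0.0]]
class_weight = [3.0, 4.0]  # L2 norm is 5
scores = normalized_scores(features, class_weight)
two_head = multi_head_logit(scores, lam=0.1, temperatures=[1, MAX_POOL_T])
```

Unlike the module above, the sketch subtracts the maximum before exponentiating. This leaves the softmax weights unchanged and only avoids overflow for large scores.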