Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support Label Assignment Distillation (LAD) #6342

Merged
merged 11 commits into from
Nov 24, 2021
32 changes: 32 additions & 0 deletions configs/lad/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Improving Object Detection by Label Assignment Distillation

<!-- [ALGORITHM] -->

```latex
@inproceedings{nguyen2021improving,
title={Improving Object Detection by Label Assignment Distillation},
author={Chuong H. Nguyen and Thuy C. Nguyen and Tuan N. Tang and Nam L. H. Phan},
booktitle = {WACV},
year={2022}
}
```

## Results and Models

We provide config files to reproduce the object detection results in the
WACV 2022 paper for Improving Object Detection by Label Assignment
Distillation.

### PAA with LAD

| Teacher | Student | Training schedule | AP (val) | Config |
| :-------: | :-----: | :---------------: | :------: | :----------------------------------------------------: |
| -- | R-50 | 1x | 40.4 | |
| -- | R-101 | 1x | 42.6 | |
| R-101 | R-50 | 1x | 41.6 | [config](configs/lad/lad_r50_paa_r101_fpn_coco_1x.py) |
| R-50 | R-101 | 1x | 43.2 | [config](configs/lad/lad_r101_paa_r50_fpn_coco_1x.py) |

## Note

- Meaning of Config name: lad_r50(student model)_paa(based on paa)_r101(teacher model)_fpn(neck)_coco(dataset)_1x(12 epoch).py
- Results may fluctuate by about 0.2 mAP.
121 changes: 121 additions & 0 deletions configs/lad/lad_r101_paa_r50_fpn_coco_1x.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/paa/paa_r50_fpn_1x_coco/paa_r50_fpn_1x_coco_20200821-936edec3.pth' # noqa
model = dict(
type='LAD',
# student
backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained',
checkpoint='torchvision://resnet101')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5),
bbox_head=dict(
type='LADHead',
reg_decoded_bbox=True,
score_voting=True,
topk=9,
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
# teacher
teacher_ckpt=teacher_ckpt,
teacher_backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
teacher_neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5),
teacher_bbox_head=dict(
type='LADHead',
reg_decoded_bbox=True,
score_voting=True,
topk=9,
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
# training and testing settings
train_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.1,
neg_iou_thr=0.1,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
score_voting=True,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))
data = dict(samples_per_gpu=8, workers_per_gpu=4)
optimizer = dict(lr=0.01)
fp16 = dict(loss_scale=512.)
120 changes: 120 additions & 0 deletions configs/lad/lad_r50_paa_r101_fpn_coco_1x.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
teacher_ckpt = 'http://download.openmmlab.com/mmdetection/v2.0/paa/paa_r101_fpn_1x_coco/paa_r101_fpn_1x_coco_20200821-0a1825a4.pth' # noqa
model = dict(
type='LAD',
# student
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5),
bbox_head=dict(
type='LADHead',
reg_decoded_bbox=True,
score_voting=True,
topk=9,
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
# teacher
teacher_ckpt=teacher_ckpt,
teacher_backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
teacher_neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5),
teacher_bbox_head=dict(
type='LADHead',
reg_decoded_bbox=True,
score_voting=True,
topk=9,
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
# training and testing settings
train_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.1,
neg_iou_thr=0.1,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
score_voting=True,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))
data = dict(samples_per_gpu=8, workers_per_gpu=4)
optimizer = dict(lr=0.01)
fp16 = dict(loss_scale=512.)
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .ga_rpn_head import GARPNHead
from .gfl_head import GFLHead
from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
from .lad_head import LADHead
from .ld_head import LDHead
from .nasfcos_head import NASFCOSHead
from .paa_head import PAAHead
Expand Down Expand Up @@ -47,5 +48,5 @@
'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead',
'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead',
'SOLOHead', 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead',
'DecoupledSOLOLightHead'
'DecoupledSOLOLightHead', 'LADHead'
]
Loading