-
Notifications
You must be signed in to change notification settings - Fork 9.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Support Label Assignment Distillation (LAD) (#6342)
* add LAD * inherit LAD from KnowledgeDistillationSingleStageDetector * add configs/lad/lad_r101_paa_r50_fpn_coco_1x.py * update LAD readme * update configs/lad/README.md * try not to use abbreviations for variable names * add unittest for lad_head * update test_lad_head * remove main in tests/test_models/test_dense_heads/test_lad_head.py
- Loading branch information
Showing
8 changed files
with
748 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Improving Object Detection by Label Assignment Distillation | ||
|
||
<!-- [ALGORITHM] --> | ||
|
||
```latex | ||
@inproceedings{nguyen2021improving, | ||
title={Improving Object Detection by Label Assignment Distillation}, | ||
author={Chuong H. Nguyen and Thuy C. Nguyen and Tuan N. Tang and Nam L. H. Phan}, | ||
booktitle = {WACV}, | ||
year={2022} | ||
} | ||
``` | ||
|
||
## Results and Models | ||
|
||
We provide config files to reproduce the object detection results in the | ||
WACV 2022 paper for Improving Object Detection by Label Assignment | ||
Distillation. | ||
|
||
### PAA with LAD | ||
|
||
| Teacher | Student | Training schedule | AP (val) | Config | | ||
| :-------: | :-----: | :---------------: | :------: | :----------------------------------------------------: | | ||
| -- | R-50 | 1x | 40.4 | | | ||
| -- | R-101 | 1x | 42.6 | | | ||
| R-101 | R-50 | 1x | 41.6 | [config](./lad_r50_paa_r101_fpn_coco_1x.py) | | ||
| R-50 | R-101 | 1x | 43.2 | [config](./lad_r101_paa_r50_fpn_coco_1x.py) | | ||
|
||
## Note | ||
|
||
- Meaning of the config name: lad_r50(student model)_paa(based on PAA)_r101(teacher model)_fpn(neck)_coco(dataset)_1x(12 epochs).py | ||
- Results may fluctuate by about 0.2 mAP. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# LAD (Label Assignment Distillation) config built on PAA.
# Student: ResNet-101 backbone.  Teacher: ResNet-50 backbone.
_base_ = [
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py',
    '../_base_/default_runtime.py',
]

# Checkpoint URL handed to the detector as ``teacher_ckpt`` (PAA R-50, 1x).
teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/paa/paa_r50_fpn_1x_coco/paa_r50_fpn_1x_coco_20200821-936edec3.pth'  # noqa

model = dict(
    type='LAD',
    # ------------------------------------------------------------------
    # student
    # ------------------------------------------------------------------
    backbone=dict(
        type='ResNet',
        depth=101,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(
            type='Pretrained', checkpoint='torchvision://resnet101')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs='on_output',
        num_outs=5),
    bbox_head=dict(
        type='LADHead',
        reg_decoded_bbox=True,
        score_voting=True,
        topk=9,
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            octave_base_scale=8,
            scales_per_octave=1,
            strides=[8, 16, 32, 64, 128]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[0.0, 0.0, 0.0, 0.0],
            target_stds=[0.1, 0.1, 0.2, 0.2]),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
        loss_centerness=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
    # ------------------------------------------------------------------
    # teacher
    # ------------------------------------------------------------------
    teacher_ckpt=teacher_ckpt,
    teacher_backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch'),
    teacher_neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs='on_output',
        num_outs=5),
    # The teacher head settings below mirror the student ``bbox_head``.
    teacher_bbox_head=dict(
        type='LADHead',
        reg_decoded_bbox=True,
        score_voting=True,
        topk=9,
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            octave_base_scale=8,
            scales_per_octave=1,
            strides=[8, 16, 32, 64, 128]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[0.0, 0.0, 0.0, 0.0],
            target_stds=[0.1, 0.1, 0.2, 0.2]),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
        loss_centerness=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
    # ------------------------------------------------------------------
    # training and testing settings
    # ------------------------------------------------------------------
    train_cfg=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.1,
            neg_iou_thr=0.1,
            min_pos_iou=0,
            ignore_iof_thr=-1),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        score_voting=True,
        nms=dict(type='nms', iou_threshold=0.6),
        max_per_img=100))

# Overrides for the data loader, optimizer and fp16 settings.
data = dict(samples_per_gpu=8, workers_per_gpu=4)
optimizer = dict(lr=0.01)
fp16 = dict(loss_scale=512.0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# LAD (Label Assignment Distillation) config built on PAA.
# Student: ResNet-50 backbone.  Teacher: ResNet-101 backbone.
_base_ = [
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# Checkpoint URL handed to the detector as ``teacher_ckpt`` (PAA R-101, 1x).
# Uses https so the weights are not fetched over an unencrypted channel.
teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/paa/paa_r101_fpn_1x_coco/paa_r101_fpn_1x_coco_20200821-0a1825a4.pth'  # noqa
model = dict(
    type='LAD',
    # student
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs='on_output',
        num_outs=5),
    bbox_head=dict(
        type='LADHead',
        reg_decoded_bbox=True,
        score_voting=True,
        topk=9,
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            octave_base_scale=8,
            scales_per_octave=1,
            strides=[8, 16, 32, 64, 128]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[0.1, 0.1, 0.2, 0.2]),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
        loss_centerness=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
    # teacher
    teacher_ckpt=teacher_ckpt,
    teacher_backbone=dict(
        type='ResNet',
        depth=101,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch'),
    teacher_neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs='on_output',
        num_outs=5),
    # The teacher head settings below mirror the student ``bbox_head``.
    teacher_bbox_head=dict(
        type='LADHead',
        reg_decoded_bbox=True,
        score_voting=True,
        topk=9,
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            octave_base_scale=8,
            scales_per_octave=1,
            strides=[8, 16, 32, 64, 128]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[0.1, 0.1, 0.2, 0.2]),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
        loss_centerness=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.1,
            neg_iou_thr=0.1,
            min_pos_iou=0,
            ignore_iof_thr=-1),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        score_voting=True,
        nms=dict(type='nms', iou_threshold=0.6),
        max_per_img=100))
# Overrides for the data loader, optimizer and fp16 settings.
data = dict(samples_per_gpu=8, workers_per_gpu=4)
optimizer = dict(lr=0.01)
fp16 = dict(loss_scale=512.)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.