[CodeCamp #122] Support KD algorithm MGD for detection. #377

@@ -0,0 +1,30 @@
# MGD

> [Masked Generative Distillation](https://arxiv.org/abs/2205.01529)

<!-- [ALGORITHM] -->

## Abstract

Knowledge distillation has been applied to various tasks successfully. The current distillation algorithm usually improves students' performance by imitating the output of the teacher. This paper shows that teachers can also improve students' representation power by guiding students' feature recovery. From this point of view, we propose Masked Generative Distillation (MGD), which is simple: we mask random pixels of the student's feature and force it to generate the teacher's full feature through a simple block. MGD is a truly general feature-based distillation method, which can be utilized on various tasks, including image classification, object detection, semantic segmentation and instance segmentation. We experiment on different models with extensive datasets and the results show that all the students achieve excellent improvements. Notably, we boost ResNet-18 from 69.90% to 71.69% ImageNet top-1 accuracy, RetinaNet with ResNet-50 backbone from 37.4 to 41.0 Boundingbox mAP, SOLO based on ResNet-50 from 33.1 to 36.2 Mask mAP and DeepLabV3 based on ResNet-18 from 73.20 to 76.02 mIoU.

![pipeline](https://github.com/yzd-v/MGD/raw/master/architecture.png)
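
In terms of the `MGDLoss` implementation added in this PR, the distillation objective can be summarized as below. The notation is ours, not the paper's, and is only meant as a reading aid:

```latex
% F^S, F^T : student / teacher feature maps (the student is mapped to the
%            teacher's channel count by a 1x1 conv when they differ)
% M        : random binary mask zeroing each spatial position with
%            probability \lambda_{mgd}
% G        : generation block (Conv 3x3 - ReLU - Conv 3x3)
L_{dis} = \frac{1}{N}\,\bigl\| \mathcal{G}\bigl(F^{S} \odot M\bigr) - F^{T} \bigr\|_2^{2},
\qquad
L_{total} = L_{task} + \alpha_{mgd}\, L_{dis}
```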

## Results and models

### Detection

| Location | Dataset | Teacher | Student | Lr schd | mAP | mAP(T) | mAP(S) | Config | Download |
| :------: | :-----: | :-----: | :-----: | :-----: | :-: | :----: | :----: | :----: | :------: |
| FPN | COCO | [RetinaNet-X101](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_x101-64x4d_fpn_1x_coco.py) | [RetinaNet-R50](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_r50_fpn_2x_coco.py) | 2x |  | 41.0 | 37.4 | [config](mgd_fpn_retina_x101_retina_r50_2x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth) \| [model](<>) \| [log](<>) |
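
Training with the linked config can be launched with MMRazor's standard `tools/train.py` entry point, or programmatically as sketched below. The config path and work directory here are placeholder assumptions, not the final locations in the repository:

```python
# Minimal launch sketch using MMEngine's Runner; paths are placeholders.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/distill/mmdet/mgd/mgd_fpn_retina_x101_retina_r50_2x_coco.py')
cfg.work_dir = './work_dirs/mgd_fpn_retina_x101_retina_r50_2x_coco'

runner = Runner.from_cfg(cfg)
runner.train()
```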

## Citation

```latex
@article{yang2022masked,
  title={Masked Generative Distillation},
  author={Yang, Zhendong and Li, Zhe and Shao, Mingqi and Shi, Dachuan and Yuan, Zehuan and Yuan, Chun},
  journal={arXiv preprint arXiv:2205.01529},
  year={2022}
}
```

@@ -0,0 +1,102 @@
_base_ = ['mmdet::retinanet/retinanet_r50_fpn_2x_coco.py']

teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth'  # noqa: E501

# Initialize the student's neck and bbox_head from the teacher checkpoint.
student = _base_.model
student.neck.init_cfg = dict(
    type='Pretrained', prefix='neck.', checkpoint=teacher_ckpt)
student.bbox_head.init_cfg = dict(
    type='Pretrained', prefix='bbox_head.', checkpoint=teacher_ckpt)

model = dict(
    _scope_='mmrazor',
    _delete_=True,
    type='FpnTeacherDistill',
    architecture=student,
    teacher=dict(
        cfg_path='mmdet::retinanet/retinanet_x101-64x4d_fpn_1x_coco.py',
        pretrained=False),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        # Record the outputs of the five FPN levels for both models.
        student_recorders=dict(
            fpn0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'),
            fpn1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'),
            fpn2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'),
            fpn3=dict(type='ModuleOutputs', source='neck.fpn_convs.3.conv'),
            fpn4=dict(type='ModuleOutputs', source='neck.fpn_convs.4.conv')),
        teacher_recorders=dict(
            fpn0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'),
            fpn1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'),
            fpn2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'),
            fpn3=dict(type='ModuleOutputs', source='neck.fpn_convs.3.conv'),
            fpn4=dict(type='ModuleOutputs', source='neck.fpn_convs.4.conv')),
        # One MGDLoss per FPN level; alpha_mgd weights the distillation loss
        # and lambda_mgd is the masked ratio.
        distill_losses=dict(
            loss_mgd_fpn0=dict(
                type='MGDLoss',
                student_channels=256,
                teacher_channels=256,
                alpha_mgd=0.00002,
                lambda_mgd=0.65),
            loss_mgd_fpn1=dict(
                type='MGDLoss',
                student_channels=256,
                teacher_channels=256,
                alpha_mgd=0.00002,
                lambda_mgd=0.65),
            loss_mgd_fpn2=dict(
                type='MGDLoss',
                student_channels=256,
                teacher_channels=256,
                alpha_mgd=0.00002,
                lambda_mgd=0.65),
            loss_mgd_fpn3=dict(
                type='MGDLoss',
                student_channels=256,
                teacher_channels=256,
                alpha_mgd=0.00002,
                lambda_mgd=0.65),
            loss_mgd_fpn4=dict(
                type='MGDLoss',
                student_channels=256,
                teacher_channels=256,
                alpha_mgd=0.00002,
                lambda_mgd=0.65)),
        # Pair each recorded student feature with the teacher feature of the
        # same FPN level.
        loss_forward_mappings=dict(
            loss_mgd_fpn0=dict(
                preds_S=dict(from_student=True, recorder='fpn0'),
                preds_T=dict(from_student=False, recorder='fpn0')),
            loss_mgd_fpn1=dict(
                preds_S=dict(from_student=True, recorder='fpn1'),
                preds_T=dict(from_student=False, recorder='fpn1')),
            loss_mgd_fpn2=dict(
                preds_S=dict(from_student=True, recorder='fpn2'),
                preds_T=dict(from_student=False, recorder='fpn2')),
            loss_mgd_fpn3=dict(
                preds_S=dict(from_student=True, recorder='fpn3'),
                preds_T=dict(from_student=False, recorder='fpn3')),
            loss_mgd_fpn4=dict(
                preds_S=dict(from_student=True, recorder='fpn4'),
                preds_T=dict(from_student=False, recorder='fpn4')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')

optimizer_config = dict(
    _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))

param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
    dict(
        type='MultiStepLR',
        begin=0,
        end=24,
        by_epoch=True,
        milestones=[16, 22],
        gamma=0.1)
]

optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))
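
Roughly speaking, `ConfigurableDistiller` records the listed module outputs during the student and teacher forward passes and then feeds each recorded pair to its mapped loss. The sketch below imitates what the `loss_mgd_fpn0` entry amounts to, with dummy tensors standing in for the recorded features; the shapes are placeholders, and `MGDLoss` must already be registered (e.g. by importing `mmrazor.models`):

```python
# Illustration only: the real wiring is performed by ConfigurableDistiller.
import torch

from mmrazor.registry import MODELS

loss_mgd_fpn0 = MODELS.build(
    dict(
        type='MGDLoss',
        student_channels=256,
        teacher_channels=256,
        alpha_mgd=0.00002,
        lambda_mgd=0.65))

# Dummy stand-ins for the recorded `neck.fpn_convs.0.conv` outputs.
preds_S = torch.rand(2, 256, 100, 152)
preds_T = torch.rand(2, 256, 100, 152)

loss = loss_mgd_fpn0(preds_S=preds_S, preds_T=preds_T)  # already scaled by alpha_mgd
```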

@@ -0,0 +1,80 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch
import torch.nn as nn

from mmrazor.registry import MODELS


@MODELS.register_module()
class MGDLoss(nn.Module):
    """PyTorch version of `Masked Generative Distillation
    <https://arxiv.org/abs/2205.01529>`_.

    Args:
        student_channels (int): Number of channels in the student's
            feature map.
        teacher_channels (int): Number of channels in the teacher's
            feature map.
        alpha_mgd (float, optional): Weight of dis_loss. Defaults to 0.00002.
        lambda_mgd (float, optional): Masked ratio. Defaults to 0.65.
    """

    def __init__(
        self,
        student_channels,
        teacher_channels,
        alpha_mgd=0.00002,
        lambda_mgd=0.65,
    ):
        super(MGDLoss, self).__init__()
        self.alpha_mgd = alpha_mgd
        self.lambda_mgd = lambda_mgd

        # A 1x1 conv aligns the student's channels to the teacher's when
        # they differ; otherwise no alignment is needed.
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(
                student_channels,
                teacher_channels,
                kernel_size=1,
                stride=1,
                padding=0)
        else:
            self.align = None

        # The generation block recovers the teacher's full feature from the
        # masked student feature.
        self.generation = nn.Sequential(
            nn.Conv2d(
                teacher_channels, teacher_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                teacher_channels, teacher_channels, kernel_size=3, padding=1))

    def forward(self, preds_S, preds_T):
        """Forward function.

        Args:
            preds_S (Tensor): Bs*C*H*W, student's feature map.
            preds_T (Tensor): Bs*C*H*W, teacher's feature map.
        """
        assert preds_S.shape[-2:] == preds_T.shape[-2:]

        if self.align is not None:
            preds_S = self.align(preds_S)

        loss = self.get_dis_loss(preds_S, preds_T) * self.alpha_mgd

        return loss

    def get_dis_loss(self, preds_S, preds_T):
        loss_mse = nn.MSELoss(reduction='sum')

        N, C, H, W = preds_T.shape

        device = preds_S.device
        # Random binary mask: each spatial position is zeroed with
        # probability lambda_mgd.
        mat = torch.rand((N, 1, H, W)).to(device)
        mat = torch.where(mat > 1 - self.lambda_mgd, 0, 1).to(device)

        masked_fea = torch.mul(preds_S, mat)
        new_fea = self.generation(masked_fea)

        dis_loss = loss_mse(new_fea, preds_T) / N

        return dis_loss
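
The random mask in `get_dis_loss` zeroes each spatial position with probability `lambda_mgd`. The snippet below only illustrates that construction in isolation; since the values are random, the measured ratio is approximate:

```python
import torch

lambda_mgd = 0.65
N, H, W = 4, 64, 64

# Same construction as in get_dis_loss above: positions whose random value
# exceeds 1 - lambda_mgd are masked out (set to 0).
mat = torch.rand((N, 1, H, W))
mat = torch.where(mat > 1 - lambda_mgd, 0, 1)

print(1 - mat.float().mean())  # roughly 0.65 of the positions are masked
```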

Review comment: Please contact @pppppM to upload the relevant `log` and `model` files.