* [Feature] Support KD algorithm MGD for detection.
* use connector to beauty mgd.
* fix typo, add unitest.
* fix mgd loss unitest.
* fix mgd connector unitest.
* add model pth and log file.
* add mAP.
1 parent bcd6878, commit 5ebf839
Showing 8 changed files with 294 additions and 5 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# MGD | ||
|
||
> [Masked Generative Distillation](https://arxiv.org/abs/2205.01529) | ||
<!-- [ALGORITHM] --> | ||
|
||
## Abstract | ||
|
||
Knowledge distillation has been applied to various tasks successfully. The current distillation algorithm usually improves students' performance by imitating the output of the teacher. This paper shows that teachers can also improve students' representation power by guiding students' feature recovery. From this point of view, we propose Masked Generative Distillation (MGD), which is simple: we mask random pixels of the student's feature and force it to generate the teacher's full feature through a simple block. MGD is a truly general feature-based distillation method, which can be utilized on various tasks, including image classification, object detection, semantic segmentation and instance segmentation. We experiment on different models with extensive datasets and the results show that all the students achieve excellent improvements. Notably, we boost ResNet-18 from 69.90% to 71.69% ImageNet top-1 accuracy, RetinaNet with ResNet-50 backbone from 37.4 to 41.0 Boundingbox mAP, SOLO based on ResNet-50 from 33.1 to 36.2 Mask mAP and DeepLabV3 based on ResNet-18 from 73.20 to 76.02 mIoU. | ||
|
||
![pipeline](https://github.com/yzd-v/MGD/raw/master/architecture.png) | ||
|
||
## Results and models | ||
|
||
### Detection | ||
|
||
| Location | Dataset | Teacher | Student | Lr schd | mAP | mAP(T) | mAP(S) | Config | Download | | ||
| :------: | :-----: | :----------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------: | :-----: | :--: | :----: | :----: | :-------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ||
| FPN | COCO | [RetinaNet-X101](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_x101-64x4d_fpn_1x_coco.py) | [RetinaNet-R50](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_r50_fpn_2x_coco.py) | 2x | 41.0 | 41.0 | 37.4 | [config](mgd_fpn_retina_x101_retina_r50_2x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/mgd/mgd_fpn_retina_x101_retina_r50_2x_coco_20221209_191847-87141529.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/mgd/mgd_fpn_retina_x101_retina_r50_2x_coco_20221209_191847-87141529.log) | | ||
|
||
## Citation | ||
|
||
```latex | ||
@article{yang2022masked, | ||
title={Masked Generative Distillation}, | ||
author={Yang, Zhendong and Li, Zhe and Shao, Mingqi and Shi, Dachuan and Yuan, Zehuan and Yuan, Chun}, | ||
journal={arXiv preprint arXiv:2205.01529}, | ||
year={2022} | ||
} | ||
``` |
118 changes: 118 additions & 0 deletions
configs/distill/mmdet/mgd/mgd_fpn_retina_x101_retina_r50_2x_coco.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
_base_ = ['mmdet::retinanet/retinanet_r50_fpn_2x_coco.py'] | ||
|
||
teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth' # noqa: E501 | ||
|
||
student = _base_.model | ||
student.neck.init_cfg = dict( | ||
    type='Pretrained', prefix='neck.', checkpoint=teacher_ckpt) | ||
student.bbox_head.init_cfg = dict( | ||
    type='Pretrained', prefix='bbox_head.', checkpoint=teacher_ckpt) | ||
|
||
model = dict( | ||
    _scope_='mmrazor', | ||
    _delete_=True, | ||
    type='FpnTeacherDistill', | ||
    architecture=student, | ||
    teacher=dict( | ||
        cfg_path='mmdet::retinanet/retinanet_x101-64x4d_fpn_1x_coco.py', | ||
        pretrained=False), | ||
    teacher_ckpt=teacher_ckpt, | ||
    distiller=dict( | ||
        type='ConfigurableDistiller', | ||
        student_recorders=dict( | ||
            fpn0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'), | ||
            fpn1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'), | ||
            fpn2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'), | ||
            fpn3=dict(type='ModuleOutputs', source='neck.fpn_convs.3.conv'), | ||
            fpn4=dict(type='ModuleOutputs', source='neck.fpn_convs.4.conv')), | ||
        teacher_recorders=dict( | ||
            fpn0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'), | ||
            fpn1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'), | ||
            fpn2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'), | ||
            fpn3=dict(type='ModuleOutputs', source='neck.fpn_convs.3.conv'), | ||
            fpn4=dict(type='ModuleOutputs', source='neck.fpn_convs.4.conv')), | ||
        connectors=dict( | ||
            s_fpn0_connector=dict( | ||
                type='MGDConnector', | ||
                student_channels=256, | ||
                teacher_channels=256, | ||
                lambda_mgd=0.65), | ||
            s_fpn1_connector=dict( | ||
                type='MGDConnector', | ||
                student_channels=256, | ||
                teacher_channels=256, | ||
                lambda_mgd=0.65), | ||
            s_fpn2_connector=dict( | ||
                type='MGDConnector', | ||
                student_channels=256, | ||
                teacher_channels=256, | ||
                lambda_mgd=0.65), | ||
            s_fpn3_connector=dict( | ||
                type='MGDConnector', | ||
                student_channels=256, | ||
                teacher_channels=256, | ||
                lambda_mgd=0.65), | ||
            s_fpn4_connector=dict( | ||
                type='MGDConnector', | ||
                student_channels=256, | ||
                teacher_channels=256, | ||
                lambda_mgd=0.65)), | ||
        distill_losses=dict( | ||
            loss_mgd_fpn0=dict(type='MGDLoss', alpha_mgd=0.00002), | ||
            loss_mgd_fpn1=dict(type='MGDLoss', alpha_mgd=0.00002), | ||
            loss_mgd_fpn2=dict(type='MGDLoss', alpha_mgd=0.00002), | ||
            loss_mgd_fpn3=dict(type='MGDLoss', alpha_mgd=0.00002), | ||
            loss_mgd_fpn4=dict(type='MGDLoss', alpha_mgd=0.00002)), | ||
        loss_forward_mappings=dict( | ||
            loss_mgd_fpn0=dict( | ||
                preds_S=dict( | ||
                    from_student=True, | ||
                    recorder='fpn0', | ||
                    connector='s_fpn0_connector'), | ||
                preds_T=dict(from_student=False, recorder='fpn0')), | ||
            loss_mgd_fpn1=dict( | ||
                preds_S=dict( | ||
                    from_student=True, | ||
                    recorder='fpn1', | ||
                    connector='s_fpn1_connector'), | ||
                preds_T=dict(from_student=False, recorder='fpn1')), | ||
            loss_mgd_fpn2=dict( | ||
                preds_S=dict( | ||
                    from_student=True, | ||
                    recorder='fpn2', | ||
                    connector='s_fpn2_connector'), | ||
                preds_T=dict(from_student=False, recorder='fpn2')), | ||
            loss_mgd_fpn3=dict( | ||
                preds_S=dict( | ||
                    from_student=True, | ||
                    recorder='fpn3', | ||
                    connector='s_fpn3_connector'), | ||
                preds_T=dict(from_student=False, recorder='fpn3')), | ||
            loss_mgd_fpn4=dict( | ||
                preds_S=dict( | ||
                    from_student=True, | ||
                    recorder='fpn4', | ||
                    connector='s_fpn4_connector'), | ||
                preds_T=dict(from_student=False, recorder='fpn4'))))) | ||
|
||
find_unused_parameters = True | ||
|
||
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop') | ||
|
||
optimizer_config = dict( | ||
    _delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) | ||
|
||
param_scheduler = [ | ||
    dict( | ||
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500), | ||
    dict( | ||
        type='MultiStepLR', | ||
        begin=0, | ||
        end=24, | ||
        by_epoch=True, | ||
        milestones=[16, 22], | ||
        gamma=0.1) | ||
] | ||
|
||
optim_wrapper = dict( | ||
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)) |
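
The five FPN levels in this config repeat one pattern for the recorders, connectors, losses and loss mappings. As a rough sketch (not part of the commit), the same dictionaries could be generated with a loop. The helper name `build_mgd_distiller` is hypothetical, and plain `dict`s stand in for the config objects; only the key names and values are taken from the config.

```python
def build_mgd_distiller(num_levels=5, channels=256,
                        lambda_mgd=0.65, alpha_mgd=0.00002):
    """Build the repeated per-FPN-level entries of the distiller config.

    Hypothetical helper for illustration; key names mirror the config.
    """
    recorders, connectors, losses, mappings = {}, {}, {}, {}
    for i in range(num_levels):
        name = f'fpn{i}'
        connector_name = f's_{name}_connector'
        loss_name = f'loss_mgd_{name}'
        # Record the output of the i-th FPN output convolution.
        recorders[name] = dict(
            type='ModuleOutputs', source=f'neck.fpn_convs.{i}.conv')
        # One MGD connector per level, applied to the student feature.
        connectors[connector_name] = dict(
            type='MGDConnector',
            student_channels=channels,
            teacher_channels=channels,
            lambda_mgd=lambda_mgd)
        losses[loss_name] = dict(type='MGDLoss', alpha_mgd=alpha_mgd)
        # Student features pass through the connector; teacher features do not.
        mappings[loss_name] = dict(
            preds_S=dict(from_student=True, recorder=name,
                         connector=connector_name),
            preds_T=dict(from_student=False, recorder=name))
    return dict(
        type='ConfigurableDistiller',
        student_recorders=recorders,
        # Copy each entry so student and teacher dicts stay independent.
        teacher_recorders={k: dict(v) for k, v in recorders.items()},
        connectors=connectors,
        distill_losses=losses,
        loss_forward_mappings=mappings)


distiller = build_mgd_distiller()
print(sorted(distiller['connectors']))
```

Whether such a helper can live directly inside a config file depends on the config style in use, so treat this as an illustration of the repeated structure rather than a drop-in replacement.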
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import Dict, Optional | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from mmrazor.registry import MODELS | ||
from .base_connector import BaseConnector | ||
|
||
|
||
@MODELS.register_module() | ||
class MGDConnector(BaseConnector): | ||
    """PyTorch version of `Masked Generative Distillation. | ||
|
||
    <https://arxiv.org/abs/2205.01529>` | ||
|
||
    Args: | ||
        student_channels(int): Number of channels in the student's feature map. | ||
        teacher_channels(int): Number of channels in the teacher's feature map. | ||
        lambda_mgd (float, optional): masked ratio. Defaults to 0.65 | ||
        init_cfg (Optional[Dict], optional): The weight initialized config for | ||
            :class:`BaseModule`. Defaults to None. | ||
    """ | ||
|
||
    def __init__( | ||
        self, | ||
        student_channels: int, | ||
        teacher_channels: int, | ||
        lambda_mgd: float = 0.65, | ||
        init_cfg: Optional[Dict] = None, | ||
    ) -> None: | ||
        super().__init__(init_cfg) | ||
        self.lambda_mgd = lambda_mgd | ||
        if student_channels != teacher_channels: | ||
            self.align = nn.Conv2d( | ||
                student_channels, | ||
                teacher_channels, | ||
                kernel_size=1, | ||
                stride=1, | ||
                padding=0) | ||
        else: | ||
            self.align = None | ||
|
||
        self.generation = nn.Sequential( | ||
            nn.Conv2d( | ||
                teacher_channels, teacher_channels, kernel_size=3, padding=1), | ||
            nn.ReLU(inplace=True), | ||
            nn.Conv2d( | ||
                teacher_channels, teacher_channels, kernel_size=3, padding=1)) | ||
|
||
    def forward_train(self, feature: torch.Tensor) -> torch.Tensor: | ||
        if self.align is not None: | ||
            feature = self.align(feature) | ||
|
||
        N, C, H, W = feature.shape | ||
|
||
        device = feature.device | ||
        mat = torch.rand((N, 1, H, W)).to(device) | ||
        mat = torch.where(mat > 1 - self.lambda_mgd, | ||
                          torch.zeros(1).to(device), | ||
                          torch.ones(1).to(device)).to(device) | ||
|
||
        masked_fea = torch.mul(feature, mat) | ||
        new_fea = self.generation(masked_fea) | ||
        return new_fea |
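
The masking step in `forward_train` draws one uniform value per spatial position and zeroes the position when that value exceeds `1 - lambda_mgd`, so roughly a `lambda_mgd` fraction of positions is masked and the same mask is shared by every channel. The dependency-free sketch below reproduces only that rule on nested lists. `make_mask` and `apply_mask` are hypothetical names, and the standard `random` module stands in for `torch.rand`; the generation block is left out.

```python
import random


def make_mask(height, width, lambda_mgd=0.65, rng=random):
    """Return an H x W mask of 0.0/1.0 values; 0.0 marks a masked position."""
    return [[0.0 if rng.random() > 1 - lambda_mgd else 1.0
             for _ in range(width)]
            for _ in range(height)]


def apply_mask(feature, mask):
    """Multiply every channel of a C x H x W feature by the same mask."""
    return [[[value * keep for value, keep in zip(row, mask_row)]
             for row, mask_row in zip(channel, mask)]
            for channel in feature]


rng = random.Random(0)  # fixed seed so the sketch is repeatable
mask = make_mask(100, 100, lambda_mgd=0.65, rng=rng)
kept = sum(map(sum, mask)) / (100 * 100)
print(f'kept fraction: {kept:.3f}')  # close to 1 - lambda_mgd

feature = [[[1.0] * 100 for _ in range(100)] for _ in range(2)]  # C = 2
masked = apply_mask(feature, mask)
```

With `lambda_mgd=0.65`, about 35% of the positions survive, which is what the student's generation block has to work from when reconstructing the teacher feature.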
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from mmrazor.registry import MODELS | ||
|
||
|
||
@MODELS.register_module() | ||
class MGDLoss(nn.Module): | ||
    """PyTorch version of `Masked Generative Distillation. | ||
|
||
    <https://arxiv.org/abs/2205.01529>` | ||
|
||
    Args: | ||
        alpha_mgd (float, optional): Weight of dis_loss. Defaults to 0.00002 | ||
    """ | ||
|
||
    def __init__(self, alpha_mgd: float = 0.00002) -> None: | ||
        super(MGDLoss, self).__init__() | ||
        self.alpha_mgd = alpha_mgd | ||
        self.loss_mse = nn.MSELoss(reduction='sum') | ||
|
||
    def forward(self, preds_S: torch.Tensor, | ||
                preds_T: torch.Tensor) -> torch.Tensor: | ||
        """Forward function. | ||
|
||
        Args: | ||
            preds_S(torch.Tensor): Bs*C*H*W, student's feature map | ||
            preds_T(torch.Tensor): Bs*C*H*W, teacher's feature map | ||
|
||
        Return: | ||
            torch.Tensor: The calculated loss value. | ||
        """ | ||
        assert preds_S.shape == preds_T.shape | ||
        loss = self.get_dis_loss(preds_S, preds_T) * self.alpha_mgd | ||
|
||
        return loss | ||
|
||
    def get_dis_loss(self, preds_S: torch.Tensor, | ||
                     preds_T: torch.Tensor) -> torch.Tensor: | ||
        """Get MSE distance of preds_S and preds_T. | ||
|
||
        Args: | ||
            preds_S(torch.Tensor): Bs*C*H*W, student's feature map | ||
            preds_T(torch.Tensor): Bs*C*H*W, teacher's feature map | ||
|
||
        Return: | ||
            torch.Tensor: The calculated mse distance value. | ||
        """ | ||
        N, C, H, W = preds_T.shape | ||
        dis_loss = self.loss_mse(preds_S, preds_T) / N | ||
|
||
        return dis_loss |
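
Put as arithmetic, `MGDLoss` computes `alpha_mgd * sum((S - T)^2) / N`: a sum-reduced squared error divided by the batch size only, not by the number of elements. A small pure-Python sketch of the same calculation follows, with one flat list per sample standing in for the `Bs*C*H*W` tensors. The function name `mgd_loss` is hypothetical.

```python
def mgd_loss(preds_s, preds_t, alpha_mgd=0.00002):
    """Sum of squared differences, divided by batch size, scaled by alpha_mgd.

    preds_s / preds_t: one flat list of floats per sample in the batch.
    """
    assert len(preds_s) == len(preds_t)
    assert all(len(s) == len(t) for s, t in zip(preds_s, preds_t))
    n = len(preds_t)  # batch size N
    squared_error = sum((s - t) ** 2
                        for sample_s, sample_t in zip(preds_s, preds_t)
                        for s, t in zip(sample_s, sample_t))
    return alpha_mgd * squared_error / n


student = [[1.0, 2.0], [0.0, 0.0]]
teacher = [[0.0, 0.0], [0.0, 0.0]]
# (1^2 + 2^2) / 2 samples = 2.5 when alpha_mgd is 1.0
print(mgd_loss(student, teacher, alpha_mgd=1.0))  # 2.5
```

Because the error is summed over all channels and positions, its raw magnitude grows with the feature-map size, which is consistent with the small default `alpha_mgd`.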