-
Notifications
You must be signed in to change notification settings - Fork 228
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CodeCamp #122] Support KD algorithm MGD for detection. #377
Merged
Merged
Changes from 5 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
d070987
[Feature] Support KD algorithm MGD for detection.
TinyTigerPan 5c54914
use connector to beauty mgd.
TinyTigerPan a9c9c05
fix typo, add unitest.
TinyTigerPan f391a59
fix mgd loss unitest.
TinyTigerPan 522e6f7
fix mgd connector unitest.
TinyTigerPan c96c068
add model pth and log file.
TinyTigerPan f221c20
add mAP.
TinyTigerPan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# MGD | ||
|
||
> [Masked Generative Distillation](https://arxiv.org/abs/2205.01529) | ||
|
||
<!-- [ALGORITHM] --> | ||
|
||
## Abstract | ||
|
||
Knowledge distillation has been applied to various tasks successfully. The current distillation algorithm usually improves students' performance by imitating the output of the teacher. This paper shows that teachers can also improve students' representation power by guiding students' feature recovery. From this point of view, we propose Masked Generative Distillation (MGD), which is simple: we mask random pixels of the student's feature and force it to generate the teacher's full feature through a simple block. MGD is a truly general feature-based distillation method, which can be utilized on various tasks, including image classification, object detection, semantic segmentation and instance segmentation. We experiment on different models with extensive datasets and the results show that all the students achieve excellent improvements. Notably, we boost ResNet-18 from 69.90% to 71.69% ImageNet top-1 accuracy, RetinaNet with ResNet-50 backbone from 37.4 to 41.0 Boundingbox mAP, SOLO based on ResNet-50 from 33.1 to 36.2 Mask mAP and DeepLabV3 based on ResNet-18 from 73.20 to 76.02 mIoU. | ||
|
||
![pipeline](https://github.com/yzd-v/MGD/raw/master/architecture.png) | ||
|
||
## Results and models | ||
|
||
### Detection | ||
|
||
| Location | Dataset | Teacher | Student | Lr schd | mAP | mAP(T) | mAP(S) | Config | Download | | ||
| :------: | :-----: | :----------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------: | :-----: | :-: | :----: | :----: | :-------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ||
| FPN | COCO | [RetinaNet-X101](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_x101-64x4d_fpn_1x_coco.py) | [RetinaNet-R50](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_r50_fpn_2x_coco.py) | 2x | | 41.0 | 37.4 | [config](mgd_fpn_retina_x101_retina_r50_2x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth) \|[model](<>) \| [log](<>) | | ||
|
||
## Citation | ||
|
||
```latex | ||
@article{yang2022masked, | ||
title={Masked Generative Distillation}, | ||
author={Yang, Zhendong and Li, Zhe and Shao, Mingqi and Shi, Dachuan and Yuan, Zehuan and Yuan, Chun}, | ||
journal={arXiv preprint arXiv:2205.01529}, | ||
year={2022} | ||
} | ||
``` |
118 changes: 118 additions & 0 deletions
118
configs/distill/mmdet/mgd/mgd_fpn_retina_x101_retina_r50_2x_coco.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
_base_ = ['mmdet::retinanet/retinanet_r50_fpn_2x_coco.py'] | ||
|
||
teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth' # noqa: E501 | ||
|
||
student = _base_.model | ||
student.neck.init_cfg = dict( | ||
type='Pretrained', prefix='neck.', checkpoint=teacher_ckpt) | ||
student.bbox_head.init_cfg = dict( | ||
type='Pretrained', prefix='bbox_head.', checkpoint=teacher_ckpt) | ||
|
||
model = dict( | ||
_scope_='mmrazor', | ||
_delete_=True, | ||
type='FpnTeacherDistill', | ||
architecture=student, | ||
teacher=dict( | ||
cfg_path='mmdet::retinanet/retinanet_x101-64x4d_fpn_1x_coco.py', | ||
pretrained=False), | ||
teacher_ckpt=teacher_ckpt, | ||
distiller=dict( | ||
type='ConfigurableDistiller', | ||
student_recorders=dict( | ||
fpn0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'), | ||
fpn1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'), | ||
fpn2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'), | ||
fpn3=dict(type='ModuleOutputs', source='neck.fpn_convs.3.conv'), | ||
fpn4=dict(type='ModuleOutputs', source='neck.fpn_convs.4.conv')), | ||
teacher_recorders=dict( | ||
fpn0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'), | ||
fpn1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'), | ||
fpn2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'), | ||
fpn3=dict(type='ModuleOutputs', source='neck.fpn_convs.3.conv'), | ||
fpn4=dict(type='ModuleOutputs', source='neck.fpn_convs.4.conv')), | ||
connectors=dict( | ||
s_fpn0_connector=dict( | ||
type='MGDConnector', | ||
student_channels=256, | ||
teacher_channels=256, | ||
lambda_mgd=0.65), | ||
s_fpn1_connector=dict( | ||
type='MGDConnector', | ||
student_channels=256, | ||
teacher_channels=256, | ||
lambda_mgd=0.65), | ||
s_fpn2_connector=dict( | ||
type='MGDConnector', | ||
student_channels=256, | ||
teacher_channels=256, | ||
lambda_mgd=0.65), | ||
s_fpn3_connector=dict( | ||
type='MGDConnector', | ||
student_channels=256, | ||
teacher_channels=256, | ||
lambda_mgd=0.65), | ||
s_fpn4_connector=dict( | ||
type='MGDConnector', | ||
student_channels=256, | ||
teacher_channels=256, | ||
lambda_mgd=0.65)), | ||
distill_losses=dict( | ||
loss_mgd_fpn0=dict(type='MGDLoss', alpha_mgd=0.00002), | ||
loss_mgd_fpn1=dict(type='MGDLoss', alpha_mgd=0.00002), | ||
loss_mgd_fpn2=dict(type='MGDLoss', alpha_mgd=0.00002), | ||
loss_mgd_fpn3=dict(type='MGDLoss', alpha_mgd=0.00002), | ||
loss_mgd_fpn4=dict(type='MGDLoss', alpha_mgd=0.00002)), | ||
loss_forward_mappings=dict( | ||
loss_mgd_fpn0=dict( | ||
preds_S=dict( | ||
from_student=True, | ||
recorder='fpn0', | ||
connector='s_fpn0_connector'), | ||
preds_T=dict(from_student=False, recorder='fpn0')), | ||
loss_mgd_fpn1=dict( | ||
preds_S=dict( | ||
from_student=True, | ||
recorder='fpn1', | ||
connector='s_fpn1_connector'), | ||
preds_T=dict(from_student=False, recorder='fpn1')), | ||
loss_mgd_fpn2=dict( | ||
preds_S=dict( | ||
from_student=True, | ||
recorder='fpn2', | ||
connector='s_fpn2_connector'), | ||
preds_T=dict(from_student=False, recorder='fpn2')), | ||
loss_mgd_fpn3=dict( | ||
preds_S=dict( | ||
from_student=True, | ||
recorder='fpn3', | ||
connector='s_fpn3_connector'), | ||
preds_T=dict(from_student=False, recorder='fpn3')), | ||
loss_mgd_fpn4=dict( | ||
preds_S=dict( | ||
from_student=True, | ||
recorder='fpn4', | ||
connector='s_fpn4_connector'), | ||
preds_T=dict(from_student=False, recorder='fpn4'))))) | ||
|
||
find_unused_parameters = True | ||
|
||
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop') | ||
|
||
optimizer_config = dict( | ||
_delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) | ||
|
||
param_scheduler = [ | ||
dict( | ||
type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500), | ||
dict( | ||
type='MultiStepLR', | ||
begin=0, | ||
end=24, | ||
by_epoch=True, | ||
milestones=[16, 22], | ||
gamma=0.1) | ||
] | ||
|
||
optim_wrapper = dict( | ||
optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import Dict, Optional | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from mmrazor.registry import MODELS | ||
from .base_connector import BaseConnector | ||
|
||
|
||
@MODELS.register_module() | ||
class MGDConnector(BaseConnector): | ||
"""PyTorch version of `Masked Generative Distillation. | ||
|
||
<https://arxiv.org/abs/2205.01529>` | ||
|
||
Args: | ||
student_channels(int): Number of channels in the student's feature map. | ||
teacher_channels(int): Number of channels in the teacher's feature map. | ||
lambda_mgd (float, optional): masked ratio. Defaults to 0.65 | ||
init_cfg (Optional[Dict], optional): The weight initialized config for | ||
:class:`BaseModule`. Defaults to None. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
student_channels: int, | ||
teacher_channels: int, | ||
lambda_mgd: float = 0.65, | ||
init_cfg: Optional[Dict] = None, | ||
) -> None: | ||
super().__init__(init_cfg) | ||
self.lambda_mgd = lambda_mgd | ||
if student_channels != teacher_channels: | ||
self.align = nn.Conv2d( | ||
student_channels, | ||
teacher_channels, | ||
kernel_size=1, | ||
stride=1, | ||
padding=0) | ||
else: | ||
self.align = None | ||
|
||
self.generation = nn.Sequential( | ||
nn.Conv2d( | ||
teacher_channels, teacher_channels, kernel_size=3, padding=1), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d( | ||
teacher_channels, teacher_channels, kernel_size=3, padding=1)) | ||
|
||
def forward_train(self, feature: torch.Tensor) -> torch.Tensor: | ||
if self.align is not None: | ||
feature = self.align(feature) | ||
|
||
N, C, H, W = feature.shape | ||
|
||
device = feature.device | ||
mat = torch.rand((N, 1, H, W)).to(device) | ||
mat = torch.where(mat > 1 - self.lambda_mgd, | ||
torch.zeros(1).to(device), | ||
torch.ones(1).to(device)).to(device) | ||
|
||
masked_fea = torch.mul(feature, mat) | ||
new_fea = self.generation(masked_fea) | ||
return new_fea |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from mmrazor.registry import MODELS | ||
|
||
|
||
@MODELS.register_module() | ||
class MGDLoss(nn.Module): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add the test case of |
||
"""PyTorch version of `Masked Generative Distillation. | ||
|
||
<https://arxiv.org/abs/2205.01529>` | ||
|
||
Args: | ||
alpha_mgd (float, optional): Weight of dis_loss. Defaults to 0.00002 | ||
""" | ||
|
||
def __init__(self, alpha_mgd: float = 0.00002) -> None: | ||
super(MGDLoss, self).__init__() | ||
self.alpha_mgd = alpha_mgd | ||
self.loss_mse = nn.MSELoss(reduction='sum') | ||
|
||
def forward(self, preds_S: torch.Tensor, | ||
preds_T: torch.Tensor) -> torch.Tensor: | ||
"""Forward function. | ||
|
||
Args: | ||
preds_S(torch.Tensor): Bs*C*H*W, student's feature map | ||
preds_T(torch.Tensor): Bs*C*H*W, teacher's feature map | ||
|
||
Return: | ||
torch.Tensor: The calculated loss value. | ||
""" | ||
assert preds_S.shape == preds_T.shape | ||
loss = self.get_dis_loss(preds_S, preds_T) * self.alpha_mgd | ||
|
||
return loss | ||
|
||
def get_dis_loss(self, preds_S: torch.Tensor, | ||
preds_T: torch.Tensor) -> torch.Tensor: | ||
"""Get MSE distance of preds_S and preds_T. | ||
|
||
Args: | ||
preds_S(torch.Tensor): Bs*C*H*W, student's feature map | ||
preds_T(torch.Tensor): Bs*C*H*W, teacher's feature map | ||
|
||
Return: | ||
torch.Tensor: The calculated mse distance value. | ||
""" | ||
N, C, H, W = preds_T.shape | ||
dis_loss = self.loss_mse(preds_S, preds_T) / N | ||
|
||
return dis_loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please contact @pppppM to upload the relative
log
andmodel
files.