[CodeCamp #122] Support KD algorithm MGD for detection. #377

Merged · 7 commits · Jan 3, 2023 · Changes shown from 1 commit
30 changes: 30 additions & 0 deletions configs/distill/mmdet/mgd/README.md
@@ -0,0 +1,30 @@
# MGD

> [Masked Generative Distillation](https://arxiv.org/abs/2205.01529)

<!-- [ALGORITHM] -->

## Abstract

Knowledge distillation has been applied to various tasks successfully. The current distillation algorithm usually improves students' performance by imitating the output of the teacher. This paper shows that teachers can also improve students' representation power by guiding students' feature recovery. From this point of view, we propose Masked Generative Distillation (MGD), which is simple: we mask random pixels of the student's feature and force it to generate the teacher's full feature through a simple block. MGD is a truly general feature-based distillation method, which can be utilized on various tasks, including image classification, object detection, semantic segmentation and instance segmentation. We experiment on different models with extensive datasets and the results show that all the students achieve excellent improvements. Notably, we boost ResNet-18 from 69.90% to 71.69% ImageNet top-1 accuracy, RetinaNet with ResNet-50 backbone from 37.4 to 41.0 Boundingbox mAP, SOLO based on ResNet-50 from 33.1 to 36.2 Mask mAP and DeepLabV3 based on ResNet-18 from 73.20 to 76.02 mIoU.

![pipeline](https://github.com/yzd-v/MGD/raw/master/architecture.png)
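
The recipe is easy to sketch in PyTorch. The snippet below is illustrative only, not MMRazor's API; the tensor shapes and the `generation` block are placeholders:

```python
import torch
import torch.nn as nn

# Toy tensors standing in for one FPN level: batch 2, 256 channels, 8x8.
student_feat = torch.rand(2, 256, 8, 8)
teacher_feat = torch.rand(2, 256, 8, 8)

# The "simple block" from the paper: two 3x3 convs with a ReLU in between.
generation = nn.Sequential(
    nn.Conv2d(256, 256, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, kernel_size=3, padding=1))

lambda_mgd = 0.65  # fraction of spatial positions to mask out
mask = (torch.rand(2, 1, 8, 8) > lambda_mgd).float()

# Mask random pixels of the student feature, then force it to regenerate
# the teacher's full feature map and penalise the reconstruction error.
recon = generation(student_feat * mask)
loss = nn.functional.mse_loss(recon, teacher_feat, reduction='sum') / 2
```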

## Results and models

### Detection

| Location | Dataset | Teacher | Student | Lr schd | mAP | mAP(T) | mAP(S) | Config | Download |
| :------: | :-----: | :----------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------: | :-----: | :-: | :----: | :----: | :-------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| FPN | COCO | [RetinaNet-X101](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_x101-64x4d_fpn_1x_coco.py) | [RetinaNet-R50](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_r50_fpn_2x_coco.py) | 2x | | 41.0 | 37.4 | [config](mgd_fpn_retina_x101_retina_r50_2x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth) \|[model](<>) \| [log](<>) |
Contributor comment: Please contact @pppppM to upload the relevant log and model files.


## Citation

```latex
@article{yang2022masked,
  title={Masked Generative Distillation},
  author={Yang, Zhendong and Li, Zhe and Shao, Mingqi and Shi, Dachuan and Yuan, Zehuan and Yuan, Chun},
  journal={arXiv preprint arXiv:2205.01529},
  year={2022}
}
```
102 changes: 102 additions & 0 deletions configs/distill/mmdet/mgd/mgd_fpn_retina_x101_retina_r50_2x_coco.py
@@ -0,0 +1,102 @@
_base_ = ['mmdet::retinanet/retinanet_r50_fpn_2x_coco.py']

teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth'  # noqa: E501

student = _base_.model
student.neck.init_cfg = dict(
    type='Pretrained', prefix='neck.', checkpoint=teacher_ckpt)
student.bbox_head.init_cfg = dict(
    type='Pretrained', prefix='bbox_head.', checkpoint=teacher_ckpt)

model = dict(
    _scope_='mmrazor',
    _delete_=True,
    type='FpnTeacherDistill',
    architecture=student,
    teacher=dict(
        cfg_path='mmdet::retinanet/retinanet_x101-64x4d_fpn_1x_coco.py',
        pretrained=False),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        student_recorders=dict(
            fpn0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'),
            fpn1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'),
            fpn2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'),
            fpn3=dict(type='ModuleOutputs', source='neck.fpn_convs.3.conv'),
            fpn4=dict(type='ModuleOutputs', source='neck.fpn_convs.4.conv')),
        teacher_recorders=dict(
            fpn0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'),
            fpn1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'),
            fpn2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'),
            fpn3=dict(type='ModuleOutputs', source='neck.fpn_convs.3.conv'),
            fpn4=dict(type='ModuleOutputs', source='neck.fpn_convs.4.conv')),
        distill_losses=dict(
            loss_mgd_fpn0=dict(
                type='MGDLoss',
                student_channels=256,
                teacher_channels=256,
                alpha_mgd=0.00002,
                lambda_mgd=0.65),
            loss_mgd_fpn1=dict(
                type='MGDLoss',
                student_channels=256,
                teacher_channels=256,
                alpha_mgd=0.00002,
                lambda_mgd=0.65),
            loss_mgd_fpn2=dict(
                type='MGDLoss',
                student_channels=256,
                teacher_channels=256,
                alpha_mgd=0.00002,
                lambda_mgd=0.65),
            loss_mgd_fpn3=dict(
                type='MGDLoss',
                student_channels=256,
                teacher_channels=256,
                alpha_mgd=0.00002,
                lambda_mgd=0.65),
            loss_mgd_fpn4=dict(
                type='MGDLoss',
                student_channels=256,
                teacher_channels=256,
                alpha_mgd=0.00002,
                lambda_mgd=0.65)),
        loss_forward_mappings=dict(
            loss_mgd_fpn0=dict(
                preds_S=dict(from_student=True, recorder='fpn0'),
                preds_T=dict(from_student=False, recorder='fpn0')),
            loss_mgd_fpn1=dict(
                preds_S=dict(from_student=True, recorder='fpn1'),
                preds_T=dict(from_student=False, recorder='fpn1')),
            loss_mgd_fpn2=dict(
                preds_S=dict(from_student=True, recorder='fpn2'),
                preds_T=dict(from_student=False, recorder='fpn2')),
            loss_mgd_fpn3=dict(
                preds_S=dict(from_student=True, recorder='fpn3'),
                preds_T=dict(from_student=False, recorder='fpn3')),
            loss_mgd_fpn4=dict(
                preds_S=dict(from_student=True, recorder='fpn4'),
                preds_T=dict(from_student=False, recorder='fpn4')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')

optimizer_config = dict(
    _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))

param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
    dict(
        type='MultiStepLR',
        begin=0,
        end=24,
        by_epoch=True,
        milestones=[16, 22],
        gamma=0.1)
]

optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))
3 changes: 2 additions & 1 deletion mmrazor/models/losses/__init__.py
@@ -12,6 +12,7 @@
from .kl_divergence import KLDivergence
from .l1_loss import L1Loss
from .l2_loss import L2Loss
from .mgd_loss import MGDLoss
from .ofd_loss import OFDLoss
from .pkd_loss import PKDLoss
from .relational_kd import AngleWiseRKD, DistanceWiseRKD
@@ -21,5 +22,5 @@
    'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
    'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
    'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss',
    'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss', 'MGDLoss'
]
80 changes: 80 additions & 0 deletions mmrazor/models/losses/mgd_loss.py
@@ -0,0 +1,80 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch
import torch.nn as nn

from mmrazor.registry import MODELS


@MODELS.register_module()
class MGDLoss(nn.Module):
Contributor comment: Add a test case for MGDLoss in test_distillation_losses, and add a docstring for get_dis_loss.
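
For reference, a minimal test along those lines might look like this (a sketch only; the function name and assertions are assumptions):

```python
import torch

from mmrazor.models.losses import MGDLoss


def test_mgd_loss():
    preds_S = torch.rand(2, 128, 8, 8)  # student FPN feature
    preds_T = torch.rand(2, 256, 8, 8)  # teacher FPN feature

    # Different channel numbers, so the 1x1 align conv is exercised too.
    mgd_loss = MGDLoss(student_channels=128, teacher_channels=256)

    loss = mgd_loss(preds_S, preds_T)
    assert loss.ndim == 0
    assert loss.item() >= 0
```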

"""PyTorch version of `Masked Generative Distillation.

<https://arxiv.org/abs/2205.01529>`

Args:
student_channels(int): Number of channels in the student's feature map.
teacher_channels(int): Number of channels in the teacher's feature map.
name (str): the loss name of the layer
alpha_mgd (float, optional): Weight of dis_loss. Defaults to 0.00002
lambda_mgd (float, optional): masked ratio. Defaults to 0.65
"""

def __init__(
self,
student_channels,
teacher_channels,
alpha_mgd=0.00002,
lambda_mgd=0.65,
):
super(MGDLoss, self).__init__()
self.alpha_mgd = alpha_mgd
self.lambda_mgd = lambda_mgd

if student_channels != teacher_channels:
self.align = nn.Conv2d(
Contributor comment: We suggest moving the align module and the generation module out of MGDLoss and into a connector under connectors. That way the new connector can be combined with various losses beyond the MSE loss used here.

Author reply: Yes, I will change it later.
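
A rough sketch of what that connector-based refactor could look like (hypothetical, not part of this PR; the name `MGDConnector` and its exact placement are assumptions):

```python
import torch
import torch.nn as nn

from mmrazor.registry import MODELS


@MODELS.register_module()
class MGDConnector(nn.Module):
    """Masks and regenerates the student feature so that any feature-based
    loss (not only the MSE used by MGDLoss) can be applied afterwards."""

    def __init__(self, student_channels, teacher_channels, lambda_mgd=0.65):
        super().__init__()
        self.lambda_mgd = lambda_mgd
        self.align = (
            nn.Conv2d(student_channels, teacher_channels, kernel_size=1)
            if student_channels != teacher_channels else None)
        self.generation = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(teacher_channels, teacher_channels, 3, padding=1))

    def forward(self, feature):
        if self.align is not None:
            feature = self.align(feature)
        n, _, h, w = feature.shape
        # Randomly drop lambda_mgd of the spatial positions, then regenerate.
        mask = (torch.rand(n, 1, h, w, device=feature.device) >
                self.lambda_mgd).float()
        return self.generation(feature * mask)
```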

                student_channels,
                teacher_channels,
                kernel_size=1,
                stride=1,
                padding=0)
        else:
            self.align = None

        self.generation = nn.Sequential(
            nn.Conv2d(
                teacher_channels, teacher_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                teacher_channels, teacher_channels, kernel_size=3, padding=1))

    def forward(self, preds_S, preds_T):
        """Forward function.

        Args:
            preds_S (Tensor): Bs*C*H*W, student's feature map.
            preds_T (Tensor): Bs*C*H*W, teacher's feature map.
        """
        assert preds_S.shape[-2:] == preds_T.shape[-2:]

        if self.align is not None:
            preds_S = self.align(preds_S)

        loss = self.get_dis_loss(preds_S, preds_T) * self.alpha_mgd

        return loss

    def get_dis_loss(self, preds_S, preds_T):
        loss_mse = nn.MSELoss(reduction='sum')
Contributor comment: Build loss_mse in __init__ to avoid rebuilding the loss function on every call during training.
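
A sketch of the suggested change (not a commit in this PR):

```python
# in __init__:
self.loss_mse = nn.MSELoss(reduction='sum')

# in get_dis_loss, reuse it instead of rebuilding it on each call:
dis_loss = self.loss_mse(new_fea, preds_T) / N
```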

        N, C, H, W = preds_T.shape

        device = preds_S.device
        mat = torch.rand((N, 1, H, W)).to(device)
        mat = torch.where(mat > 1 - self.lambda_mgd, 0, 1).to(device)

        masked_fea = torch.mul(preds_S, mat)
        new_fea = self.generation(masked_fea)

        dis_loss = loss_mse(new_fea, preds_T) / N

        return dis_loss
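
For a quick sanity check, the new loss can be built from the registry with the same settings as the `loss_mgd_fpn*` entries in the config above (a standalone sketch, not part of the PR):

```python
import torch

import mmrazor.models  # noqa: F401  # importing this registers MGDLoss
from mmrazor.registry import MODELS

loss_cfg = dict(
    type='MGDLoss',
    student_channels=256,
    teacher_channels=256,
    alpha_mgd=0.00002,
    lambda_mgd=0.65)
mgd_loss = MODELS.build(loss_cfg)

# Fake student/teacher features for one FPN level.
preds_S = torch.rand(2, 256, 32, 32)
preds_T = torch.rand(2, 256, 32, 32)
print(mgd_loss(preds_S, preds_T))  # small scalar, scaled by alpha_mgd
```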