Skip to content

Commit

Permalink
[CodeCamp #122] Support KD algorithm MGD for detection. (#377)
Browse files Browse the repository at this point in the history
* [Feature] Support KD algorithm MGD for detection.

* use connector to beauty mgd.

* fix typo, add unitest.

* fix mgd loss unitest.

* fix mgd connector unitest.

* add model pth and log file.

* add mAP.
  • Loading branch information
TinyTigerPan authored Jan 3, 2023
1 parent bcd6878 commit 5ebf839
Show file tree
Hide file tree
Showing 8 changed files with 294 additions and 5 deletions.
30 changes: 30 additions & 0 deletions configs/distill/mmdet/mgd/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# MGD

> [Masked Generative Distillation](https://arxiv.org/abs/2205.01529)
<!-- [ALGORITHM] -->

## Abstract

Knowledge distillation has been applied to various tasks successfully. The current distillation algorithm usually improves students' performance by imitating the output of the teacher. This paper shows that teachers can also improve students' representation power by guiding students' feature recovery. From this point of view, we propose Masked Generative Distillation (MGD), which is simple: we mask random pixels of the student's feature and force it to generate the teacher's full feature through a simple block. MGD is a truly general feature-based distillation method, which can be utilized on various tasks, including image classification, object detection, semantic segmentation and instance segmentation. We experiment on different models with extensive datasets and the results show that all the students achieve excellent improvements. Notably, we boost ResNet-18 from 69.90% to 71.69% ImageNet top-1 accuracy, RetinaNet with ResNet-50 backbone from 37.4 to 41.0 Boundingbox mAP, SOLO based on ResNet-50 from 33.1 to 36.2 Mask mAP and DeepLabV3 based on ResNet-18 from 73.20 to 76.02 mIoU.

![pipeline](https://github.com/yzd-v/MGD/raw/master/architecture.png)

## Results and models

### Detection

| Location | Dataset | Teacher | Student | Lr schd | mAP | mAP(T) | mAP(S) | Config | Download |
| :------: | :-----: | :----------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------: | :-----: | :--: | :----: | :----: | :-------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| FPN | COCO | [RetinaNet-X101](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_x101-64x4d_fpn_1x_coco.py) | [RetinaNet-R50](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_r50_fpn_2x_coco.py) | 2x | 41.0 | 41.0 | 37.4 | [config](mgd_fpn_retina_x101_retina_r50_2x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/mgd/mgd_fpn_retina_x101_retina_r50_2x_coco_20221209_191847-87141529.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/mgd/mgd_fpn_retina_x101_retina_r50_2x_coco_20221209_191847-87141529.log) |

## Citation

```latex
@article{yang2022masked,
title={Masked Generative Distillation},
author={Yang, Zhendong and Li, Zhe and Shao, Mingqi and Shi, Dachuan and Yuan, Zehuan and Yuan, Chun},
journal={arXiv preprint arXiv:2205.01529},
year={2022}
}
```
118 changes: 118 additions & 0 deletions configs/distill/mmdet/mgd/mgd_fpn_retina_x101_retina_r50_2x_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
_base_ = ['mmdet::retinanet/retinanet_r50_fpn_2x_coco.py']

teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth' # noqa: E501

student = _base_.model
student.neck.init_cfg = dict(
type='Pretrained', prefix='neck.', checkpoint=teacher_ckpt)
student.bbox_head.init_cfg = dict(
type='Pretrained', prefix='bbox_head.', checkpoint=teacher_ckpt)

model = dict(
_scope_='mmrazor',
_delete_=True,
type='FpnTeacherDistill',
architecture=student,
teacher=dict(
cfg_path='mmdet::retinanet/retinanet_x101-64x4d_fpn_1x_coco.py',
pretrained=False),
teacher_ckpt=teacher_ckpt,
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
fpn0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'),
fpn1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'),
fpn2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'),
fpn3=dict(type='ModuleOutputs', source='neck.fpn_convs.3.conv'),
fpn4=dict(type='ModuleOutputs', source='neck.fpn_convs.4.conv')),
teacher_recorders=dict(
fpn0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'),
fpn1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'),
fpn2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'),
fpn3=dict(type='ModuleOutputs', source='neck.fpn_convs.3.conv'),
fpn4=dict(type='ModuleOutputs', source='neck.fpn_convs.4.conv')),
connectors=dict(
s_fpn0_connector=dict(
type='MGDConnector',
student_channels=256,
teacher_channels=256,
lambda_mgd=0.65),
s_fpn1_connector=dict(
type='MGDConnector',
student_channels=256,
teacher_channels=256,
lambda_mgd=0.65),
s_fpn2_connector=dict(
type='MGDConnector',
student_channels=256,
teacher_channels=256,
lambda_mgd=0.65),
s_fpn3_connector=dict(
type='MGDConnector',
student_channels=256,
teacher_channels=256,
lambda_mgd=0.65),
s_fpn4_connector=dict(
type='MGDConnector',
student_channels=256,
teacher_channels=256,
lambda_mgd=0.65)),
distill_losses=dict(
loss_mgd_fpn0=dict(type='MGDLoss', alpha_mgd=0.00002),
loss_mgd_fpn1=dict(type='MGDLoss', alpha_mgd=0.00002),
loss_mgd_fpn2=dict(type='MGDLoss', alpha_mgd=0.00002),
loss_mgd_fpn3=dict(type='MGDLoss', alpha_mgd=0.00002),
loss_mgd_fpn4=dict(type='MGDLoss', alpha_mgd=0.00002)),
loss_forward_mappings=dict(
loss_mgd_fpn0=dict(
preds_S=dict(
from_student=True,
recorder='fpn0',
connector='s_fpn0_connector'),
preds_T=dict(from_student=False, recorder='fpn0')),
loss_mgd_fpn1=dict(
preds_S=dict(
from_student=True,
recorder='fpn1',
connector='s_fpn1_connector'),
preds_T=dict(from_student=False, recorder='fpn1')),
loss_mgd_fpn2=dict(
preds_S=dict(
from_student=True,
recorder='fpn2',
connector='s_fpn2_connector'),
preds_T=dict(from_student=False, recorder='fpn2')),
loss_mgd_fpn3=dict(
preds_S=dict(
from_student=True,
recorder='fpn3',
connector='s_fpn3_connector'),
preds_T=dict(from_student=False, recorder='fpn3')),
loss_mgd_fpn4=dict(
preds_S=dict(
from_student=True,
recorder='fpn4',
connector='s_fpn4_connector'),
preds_T=dict(from_student=False, recorder='fpn4')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')

optimizer_config = dict(
_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))

param_scheduler = [
dict(
type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
dict(
type='MultiStepLR',
begin=0,
end=24,
by_epoch=True,
milestones=[16, 22],
gamma=0.1)
]

optim_wrapper = dict(
optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))
3 changes: 2 additions & 1 deletion mmrazor/models/architectures/connectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from .crd_connector import CRDConnector
from .factor_transfer_connectors import Paraphraser, Translator
from .fbkd_connector import FBKDStudentConnector, FBKDTeacherConnector
from .mgd_connector import MGDConnector
from .ofd_connector import OFDTeacherConnector
from .torch_connector import TorchFunctionalConnector, TorchNNConnector

__all__ = [
'ConvModuleConnector', 'Translator', 'Paraphraser', 'BYOTConnector',
'FBKDTeacherConnector', 'FBKDStudentConnector', 'TorchFunctionalConnector',
'CRDConnector', 'TorchNNConnector', 'OFDTeacherConnector'
'CRDConnector', 'TorchNNConnector', 'OFDTeacherConnector', 'MGDConnector'
]
65 changes: 65 additions & 0 deletions mmrazor/models/architectures/connectors/mgd_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional

import torch
import torch.nn as nn

from mmrazor.registry import MODELS
from .base_connector import BaseConnector


@MODELS.register_module()
class MGDConnector(BaseConnector):
"""PyTorch version of `Masked Generative Distillation.
<https://arxiv.org/abs/2205.01529>`
Args:
student_channels(int): Number of channels in the student's feature map.
teacher_channels(int): Number of channels in the teacher's feature map.
lambda_mgd (float, optional): masked ratio. Defaults to 0.65
init_cfg (Optional[Dict], optional): The weight initialized config for
:class:`BaseModule`. Defaults to None.
"""

def __init__(
self,
student_channels: int,
teacher_channels: int,
lambda_mgd: float = 0.65,
init_cfg: Optional[Dict] = None,
) -> None:
super().__init__(init_cfg)
self.lambda_mgd = lambda_mgd
if student_channels != teacher_channels:
self.align = nn.Conv2d(
student_channels,
teacher_channels,
kernel_size=1,
stride=1,
padding=0)
else:
self.align = None

self.generation = nn.Sequential(
nn.Conv2d(
teacher_channels, teacher_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(
teacher_channels, teacher_channels, kernel_size=3, padding=1))

def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
if self.align is not None:
feature = self.align(feature)

N, C, H, W = feature.shape

device = feature.device
mat = torch.rand((N, 1, H, W)).to(device)
mat = torch.where(mat > 1 - self.lambda_mgd,
torch.zeros(1).to(device),
torch.ones(1).to(device)).to(device)

masked_fea = torch.mul(feature, mat)
new_fea = self.generation(masked_fea)
return new_fea
3 changes: 2 additions & 1 deletion mmrazor/models/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .kl_divergence import KLDivergence
from .l1_loss import L1Loss
from .l2_loss import L2Loss
from .mgd_loss import MGDLoss
from .ofd_loss import OFDLoss
from .pkd_loss import PKDLoss
from .relational_kd import AngleWiseRKD, DistanceWiseRKD
Expand All @@ -21,5 +22,5 @@
'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss',
'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss'
'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss', 'MGDLoss'
]
54 changes: 54 additions & 0 deletions mmrazor/models/losses/mgd_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch
import torch.nn as nn

from mmrazor.registry import MODELS


@MODELS.register_module()
class MGDLoss(nn.Module):
"""PyTorch version of `Masked Generative Distillation.
<https://arxiv.org/abs/2205.01529>`
Args:
alpha_mgd (float, optional): Weight of dis_loss. Defaults to 0.00002
"""

def __init__(self, alpha_mgd: float = 0.00002) -> None:
super(MGDLoss, self).__init__()
self.alpha_mgd = alpha_mgd
self.loss_mse = nn.MSELoss(reduction='sum')

def forward(self, preds_S: torch.Tensor,
preds_T: torch.Tensor) -> torch.Tensor:
"""Forward function.
Args:
preds_S(torch.Tensor): Bs*C*H*W, student's feature map
preds_T(torch.Tensor): Bs*C*H*W, teacher's feature map
Return:
torch.Tensor: The calculated loss value.
"""
assert preds_S.shape == preds_T.shape
loss = self.get_dis_loss(preds_S, preds_T) * self.alpha_mgd

return loss

def get_dis_loss(self, preds_S: torch.Tensor,
preds_T: torch.Tensor) -> torch.Tensor:
"""Get MSE distance of preds_S and preds_T.
Args:
preds_S(torch.Tensor): Bs*C*H*W, student's feature map
preds_T(torch.Tensor): Bs*C*H*W, teacher's feature map
Return:
torch.Tensor: The calculated mse distance value.
"""
N, C, H, W = preds_T.shape
dis_loss = self.loss_mse(preds_S, preds_T) / N

return dis_loss
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

from mmrazor.models import (BYOTConnector, ConvModuleConnector, CRDConnector,
FBKDStudentConnector, FBKDTeacherConnector,
Paraphraser, TorchFunctionalConnector,
TorchNNConnector, Translator)
MGDConnector, Paraphraser,
TorchFunctionalConnector, TorchNNConnector,
Translator)


class TestConnector(TestCase):
Expand Down Expand Up @@ -130,3 +131,15 @@ def test_torch_connector(self):
functional_pool_connector = TorchFunctionalConnector()
with self.assertRaises(ValueError):
functional_pool_connector = TorchNNConnector(module_name='fake')

def test_mgd_connector(self):
s_feat = torch.randn(1, 16, 8, 8)
mgd_connector1 = MGDConnector(
student_channels=16, teacher_channels=16, lambda_mgd=0.65)
mgd_connector2 = MGDConnector(
student_channels=16, teacher_channels=32, lambda_mgd=0.65)
s_output1 = mgd_connector1.forward_train(s_feat)
s_output2 = mgd_connector2.forward_train(s_feat)

assert s_output1.shape == torch.Size([1, 16, 8, 8])
assert s_output2.shape == torch.Size([1, 32, 8, 8])
9 changes: 8 additions & 1 deletion tests/test_models/test_losses/test_distillation_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from mmrazor import digit_version
from mmrazor.models import (ABLoss, ActivationLoss, ATLoss, CRDLoss, DKDLoss,
FBKDLoss, FTLoss, InformationEntropyLoss,
KDSoftCELoss, OFDLoss, OnehotLikeLoss, PKDLoss)
KDSoftCELoss, MGDLoss, OFDLoss, OnehotLikeLoss,
PKDLoss)


class TestLosses(TestCase):
Expand Down Expand Up @@ -204,3 +205,9 @@ def test_pkdloss(self):
loss = pkd_loss(feats_S, feats_T)
self.assertTrue(loss.numel() == 1)
self.assertTrue(0. <= loss <= 1.)

def test_mgd_loss(self):
mgd_loss = MGDLoss(alpha_mgd=0.00002)
feats_S, feats_T = torch.rand(2, 256, 4, 4), torch.rand(2, 256, 4, 4)
loss = mgd_loss(feats_S, feats_T)
self.assertTrue(loss.numel() == 1)

0 comments on commit 5ebf839

Please sign in to comment.