Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeCamp #122] Support KD algorithm MGD for detection. #377

Merged
merged 7 commits into from
Jan 3, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions configs/distill/mmdet/mgd/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# MGD

> [Masked Generative Distillation](https://arxiv.org/abs/2205.01529)

<!-- [ALGORITHM] -->

## Abstract

Knowledge distillation has been applied to various tasks successfully. The current distillation algorithm usually improves students' performance by imitating the output of the teacher. This paper shows that teachers can also improve students' representation power by guiding students' feature recovery. From this point of view, we propose Masked Generative Distillation (MGD), which is simple: we mask random pixels of the student's feature and force it to generate the teacher's full feature through a simple block. MGD is a truly general feature-based distillation method, which can be utilized on various tasks, including image classification, object detection, semantic segmentation and instance segmentation. We experiment on different models with extensive datasets and the results show that all the students achieve excellent improvements. Notably, we boost ResNet-18 from 69.90% to 71.69% ImageNet top-1 accuracy, RetinaNet with ResNet-50 backbone from 37.4 to 41.0 Boundingbox mAP, SOLO based on ResNet-50 from 33.1 to 36.2 Mask mAP and DeepLabV3 based on ResNet-18 from 73.20 to 76.02 mIoU.

![pipeline](https://github.com/yzd-v/MGD/raw/master/architecture.png)

## Results and models

### Detection

| Location | Dataset | Teacher | Student | Lr schd | mAP | mAP(T) | mAP(S) | Config | Download |
| :------: | :-----: | :----------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------: | :-----: | :-: | :----: | :----: | :-------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| FPN | COCO | [RetinaNet-X101](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_x101-64x4d_fpn_1x_coco.py) | [RetinaNet-R50](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_r50_fpn_2x_coco.py) | 2x | | 41.0 | 37.4 | [config](mgd_fpn_retina_x101_retina_r50_2x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth) \|[model](<>) \| [log](<>) |
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please contact @pppppM to upload the relative log and model files.


## Citation

```latex
@article{yang2022masked,
title={Masked Generative Distillation},
author={Yang, Zhendong and Li, Zhe and Shao, Mingqi and Shi, Dachuan and Yuan, Zehuan and Yuan, Chun},
journal={arXiv preprint arXiv:2205.01529},
year={2022}
}
```
118 changes: 118 additions & 0 deletions configs/distill/mmdet/mgd/mgd_fpn_retina_x101_retina_r50_2x_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
_base_ = ['mmdet::retinanet/retinanet_r50_fpn_2x_coco.py']

teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth' # noqa: E501

student = _base_.model
student.neck.init_cfg = dict(
type='Pretrained', prefix='neck.', checkpoint=teacher_ckpt)
student.bbox_head.init_cfg = dict(
type='Pretrained', prefix='bbox_head.', checkpoint=teacher_ckpt)

model = dict(
_scope_='mmrazor',
_delete_=True,
type='FpnTeacherDistill',
architecture=student,
teacher=dict(
cfg_path='mmdet::retinanet/retinanet_x101-64x4d_fpn_1x_coco.py',
pretrained=False),
teacher_ckpt=teacher_ckpt,
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
fpn0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'),
fpn1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'),
fpn2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'),
fpn3=dict(type='ModuleOutputs', source='neck.fpn_convs.3.conv'),
fpn4=dict(type='ModuleOutputs', source='neck.fpn_convs.4.conv')),
teacher_recorders=dict(
fpn0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'),
fpn1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'),
fpn2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'),
fpn3=dict(type='ModuleOutputs', source='neck.fpn_convs.3.conv'),
fpn4=dict(type='ModuleOutputs', source='neck.fpn_convs.4.conv')),
connectors=dict(
s_fpn0_connector=dict(
type='MGDConnector',
student_channels=256,
teacher_channels=256,
lambda_mgd=0.65),
s_fpn1_connector=dict(
type='MGDConnector',
student_channels=256,
teacher_channels=256,
lambda_mgd=0.65),
s_fpn2_connector=dict(
type='MGDConnector',
student_channels=256,
teacher_channels=256,
lambda_mgd=0.65),
s_fpn3_connector=dict(
type='MGDConnector',
student_channels=256,
teacher_channels=256,
lambda_mgd=0.65),
s_fpn4_connector=dict(
type='MGDConnector',
student_channels=256,
teacher_channels=256,
lambda_mgd=0.65)),
distill_losses=dict(
loss_mgd_fpn0=dict(type='MGDLoss', alpha_mgd=0.00002),
loss_mgd_fpn1=dict(type='MGDLoss', alpha_mgd=0.00002),
loss_mgd_fpn2=dict(type='MGDLoss', alpha_mgd=0.00002),
loss_mgd_fpn3=dict(type='MGDLoss', alpha_mgd=0.00002),
loss_mgd_fpn4=dict(type='MGDLoss', alpha_mgd=0.00002)),
loss_forward_mappings=dict(
loss_mgd_fpn0=dict(
preds_S=dict(
from_student=True,
recorder='fpn0',
connector='s_fpn0_connector'),
preds_T=dict(from_student=False, recorder='fpn0')),
loss_mgd_fpn1=dict(
preds_S=dict(
from_student=True,
recorder='fpn1',
connector='s_fpn1_connector'),
preds_T=dict(from_student=False, recorder='fpn1')),
loss_mgd_fpn2=dict(
preds_S=dict(
from_student=True,
recorder='fpn2',
connector='s_fpn2_connector'),
preds_T=dict(from_student=False, recorder='fpn2')),
loss_mgd_fpn3=dict(
preds_S=dict(
from_student=True,
recorder='fpn3',
connector='s_fpn3_connector'),
preds_T=dict(from_student=False, recorder='fpn3')),
loss_mgd_fpn4=dict(
preds_S=dict(
from_student=True,
recorder='fpn4',
connector='s_fpn4_connector'),
preds_T=dict(from_student=False, recorder='fpn4')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')

optimizer_config = dict(
_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))

param_scheduler = [
dict(
type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
dict(
type='MultiStepLR',
begin=0,
end=24,
by_epoch=True,
milestones=[16, 22],
gamma=0.1)
]

optim_wrapper = dict(
optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))
3 changes: 2 additions & 1 deletion mmrazor/models/architectures/connectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from .crd_connector import CRDConnector
from .factor_transfer_connectors import Paraphraser, Translator
from .fbkd_connector import FBKDStudentConnector, FBKDTeacherConnector
from .mgd_connector import MGDConnector
from .ofd_connector import OFDTeacherConnector
from .torch_connector import TorchFunctionalConnector, TorchNNConnector

__all__ = [
'ConvModuleConncetor', 'Translator', 'Paraphraser', 'BYOTConnector',
'FBKDTeacherConnector', 'FBKDStudentConnector', 'TorchFunctionalConnector',
'CRDConnector', 'TorchNNConnector', 'OFDTeacherConnector'
'CRDConnector', 'TorchNNConnector', 'OFDTeacherConnector', 'MGDConnector'
]
65 changes: 65 additions & 0 deletions mmrazor/models/architectures/connectors/mgd_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional

import torch
import torch.nn as nn

from mmrazor.registry import MODELS
from .base_connector import BaseConnector


@MODELS.register_module()
class MGDConnector(BaseConnector):
"""PyTorch version of `Masked Generative Distillation.

<https://arxiv.org/abs/2205.01529>`

Args:
student_channels(int): Number of channels in the student's feature map.
teacher_channels(int): Number of channels in the teacher's feature map.
lambda_mgd (float, optional): masked ratio. Defaults to 0.65
init_cfg (Optional[Dict], optional): The weight initialized config for
:class:`BaseModule`. Defaults to None.
"""

def __init__(
self,
student_channels: int,
teacher_channels: int,
lambda_mgd: float = 0.65,
init_cfg: Optional[Dict] = None,
) -> None:
super().__init__(init_cfg)
self.lambda_mgd = lambda_mgd
if student_channels != teacher_channels:
self.align = nn.Conv2d(
student_channels,
teacher_channels,
kernel_size=1,
stride=1,
padding=0)
else:
self.align = None

self.generation = nn.Sequential(
nn.Conv2d(
teacher_channels, teacher_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(
teacher_channels, teacher_channels, kernel_size=3, padding=1))

def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
if self.align is not None:
feature = self.align(feature)

N, C, H, W = feature.shape

device = feature.device
mat = torch.rand((N, 1, H, W)).to(device)
mat = torch.where(mat > 1 - self.lambda_mgd,
torch.zeros(1).to(device),
torch.ones(1).to(device)).to(device)

masked_fea = torch.mul(feature, mat)
new_fea = self.generation(masked_fea)
return new_fea
3 changes: 2 additions & 1 deletion mmrazor/models/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .kl_divergence import KLDivergence
from .l1_loss import L1Loss
from .l2_loss import L2Loss
from .mgd_loss import MGDLoss
from .ofd_loss import OFDLoss
from .pkd_loss import PKDLoss
from .relational_kd import AngleWiseRKD, DistanceWiseRKD
Expand All @@ -21,5 +22,5 @@
'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss',
'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss'
'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss', 'MGDLoss'
]
54 changes: 54 additions & 0 deletions mmrazor/models/losses/mgd_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch
import torch.nn as nn

from mmrazor.registry import MODELS


@MODELS.register_module()
class MGDLoss(nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add the test case of MGDLoss in test_distillation_losses.
Add the docstring of get_dis_loss.

"""PyTorch version of `Masked Generative Distillation.

<https://arxiv.org/abs/2205.01529>`

Args:
alpha_mgd (float, optional): Weight of dis_loss. Defaults to 0.00002
"""

def __init__(self, alpha_mgd: float = 0.00002) -> None:
super(MGDLoss, self).__init__()
self.alpha_mgd = alpha_mgd
self.loss_mse = nn.MSELoss(reduction='sum')

def forward(self, preds_S: torch.Tensor,
preds_T: torch.Tensor) -> torch.Tensor:
"""Forward function.

Args:
preds_S(torch.Tensor): Bs*C*H*W, student's feature map
preds_T(torch.Tensor): Bs*C*H*W, teacher's feature map

Return:
torch.Tensor: The calculated loss value.
"""
assert preds_S.shape == preds_T.shape
loss = self.get_dis_loss(preds_S, preds_T) * self.alpha_mgd

return loss

def get_dis_loss(self, preds_S: torch.Tensor,
preds_T: torch.Tensor) -> torch.Tensor:
"""Get MSE distance of preds_S and preds_T.

Args:
preds_S(torch.Tensor): Bs*C*H*W, student's feature map
preds_T(torch.Tensor): Bs*C*H*W, teacher's feature map

Return:
torch.Tensor: The calculated mse distance value.
"""
N, C, H, W = preds_T.shape
dis_loss = self.loss_mse(preds_S, preds_T) / N

return dis_loss
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

from mmrazor.models import (BYOTConnector, ConvModuleConncetor, CRDConnector,
FBKDStudentConnector, FBKDTeacherConnector,
Paraphraser, TorchFunctionalConnector,
TorchNNConnector, Translator)
MGDConnector, Paraphraser,
TorchFunctionalConnector, TorchNNConnector,
Translator)


class TestConnector(TestCase):
Expand Down Expand Up @@ -130,3 +131,15 @@ def test_torch_connector(self):
functional_pool_connector = TorchFunctionalConnector()
with self.assertRaises(ValueError):
functional_pool_connector = TorchNNConnector(module_name='fake')

def test_mgd_connector(self):
s_feat = torch.randn(1, 16, 8, 8)
mgd_connector1 = MGDConnector(
student_channels=16, teacher_channels=16, lambda_mgd=0.65)
mgd_connector2 = MGDConnector(
student_channels=16, teacher_channels=32, lambda_mgd=0.65)
s_output1 = mgd_connector1.forward_train(s_feat)
s_output2 = mgd_connector2.forward_train(s_feat)

assert s_output1.shape == torch.Size([1, 16, 8, 8])
assert s_output2.shape == torch.Size([1, 32, 8, 8])
9 changes: 8 additions & 1 deletion tests/test_models/test_losses/test_distillation_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from mmrazor import digit_version
from mmrazor.models import (ABLoss, ActivationLoss, ATLoss, CRDLoss, DKDLoss,
FBKDLoss, FTLoss, InformationEntropyLoss,
KDSoftCELoss, OFDLoss, OnehotLikeLoss, PKDLoss)
KDSoftCELoss, MGDLoss, OFDLoss, OnehotLikeLoss,
PKDLoss)


class TestLosses(TestCase):
Expand Down Expand Up @@ -204,3 +205,9 @@ def test_pkdloss(self):
loss = pkd_loss(feats_S, feats_T)
self.assertTrue(loss.numel() == 1)
self.assertTrue(0. <= loss <= 1.)

def test_mgd_loss(self):
mgd_loss = MGDLoss(alpha_mgd=0.00002)
feats_S, feats_T = torch.rand(2, 256, 4, 4), torch.rand(2, 256, 4, 4)
loss = mgd_loss(feats_S, feats_T)
self.assertTrue(loss.numel() == 1)