[Feature] Add connector components and FitNet #207
@@ -0,0 +1,48 @@
# FitNets

> [FitNets: Hints for Thin Deep Nets](https://arxiv.org/abs/1412.6550)
<!-- [ALGORITHM] -->

## Abstract

While depth tends to improve network performances, it also makes gradient-based
training more difficult since deeper networks tend to be more non-linear. The recently
proposed knowledge distillation approach is aimed at obtaining small and fast-to-execute
models, and it has shown that a student network could imitate the soft output of a larger
teacher network or ensemble of networks. In this paper, we extend this idea to allow the
training of a student that is deeper and thinner than the teacher, using not only the outputs
but also the intermediate representations learned by the teacher as hints to improve the
training process and final performance of the student. Because the student intermediate hidden
layer will generally be smaller than the teacher's intermediate hidden layer, additional parameters
are introduced to map the student hidden layer to the prediction of the teacher hidden layer. This
allows one to train deeper students that can generalize better or run faster, a trade-off that is
controlled by the chosen student capacity. For example, on CIFAR-10, a deep student network with
almost 10.4 times less parameters outperforms a larger, state-of-the-art teacher network.

![pipeline](/docs/en/imgs/model_zoo/fitnet/pipeline.png)

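The core idea from the abstract, matching an intermediate student feature to a teacher "hint" through a small regressor, can be written in a few lines of plain PyTorch. The sketch below is illustrative only: shapes and module names are invented, and it is independent of the MMRazor implementation shown later in this PR.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Invented shapes: the thin student has 256 channels where the teacher
# has 1024; spatial sizes are assumed to match.
student_feat = torch.randn(2, 256, 14, 14)
teacher_feat = torch.randn(2, 1024, 14, 14)

# The FitNets regressor (a "connector" in this PR): maps the student
# feature into the teacher's channel dimension.
regressor = nn.Conv2d(256, 1024, kernel_size=1)

# Hint loss: L2 distance between the regressed student feature and the
# teacher hint. The teacher is frozen, hence the detach().
hint_loss = F.mse_loss(regressor(student_feat), teacher_feat.detach())
print(hint_loss.item())
```
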
## Results and models

### Classification

| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :---------------: | :------: | :-----: | :-----: | :---: | :----: | :----: | :----: | :------: |
| backbone & logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.85 | 76.55 | 69.90 | [config](./fitnet_backbone_logits_resnet50_resnet18_8xb16_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \| [model](<>) \| [log](<>) |

## Citation

```latex
@inproceedings{DBLP:journals/corr/RomeroBKCGB14,
  author    = {Adriana Romero and Nicolas Ballas and Samira Ebrahimi Kahou and
               Antoine Chassang and Carlo Gatta and Yoshua Bengio},
  editor    = {Yoshua Bengio and Yann LeCun},
  title     = {FitNets: Hints for Thin Deep Nets},
  booktitle = {3rd International Conference on Learning Representations, {ICLR} 2015,
               San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings},
  year      = {2015},
  url       = {http://arxiv.org/abs/1412.6550},
  timestamp = {Thu, 25 Jul 2019 14:25:38 +0200},
  biburl    = {https://dblp.org/rec/journals/corr/RomeroBKCGB14.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
@@ -0,0 +1,59 @@
_base_ = [
    'mmcls::_base_/datasets/imagenet_bs32.py',
    'mmcls::_base_/schedules/imagenet_bs256.py',
    'mmcls::_base_/default_runtime.py'
]

model = dict(
    _scope_='mmrazor',
    type='SingleTeacherDistill',
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # RGB format normalization parameters
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        # convert image from BGR to RGB
        bgr_to_rgb=True),
    architecture=dict(
        cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
    teacher=dict(
        cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True),
    teacher_ckpt='resnet50_8xb32_in1k_20210831-ea4938fc.pth',
    distiller=dict(
        type='ConfigurableDistiller',
        student_recorders=dict(
            feat_4=dict(type='ModuleOutputs', source='backbone.layer4.1.relu'),
            feat_3=dict(type='ModuleOutputs', source='backbone.layer3.1.relu'),
            fc=dict(type='ModuleOutputs', source='head.fc')),
        teacher_recorders=dict(
            feat_4=dict(type='ModuleOutputs', source='backbone.layer4.2.relu'),
            feat_3=dict(type='ModuleOutputs', source='backbone.layer3.5.relu'),
            fc=dict(type='ModuleOutputs', source='head.fc')),
        distill_losses=dict(
            loss_f4=dict(type='L2Loss', loss_weight=10),
            loss_f3=dict(type='L2Loss', loss_weight=10),
            loss_kl=dict(
                type='KLDivergence', tau=6, loss_weight=10, reduction='mean')),
        student_connectors=dict(
            loss_f4=dict(
                type='ReLUConnector', in_channel=512, out_channel=2048),
            loss_f3=dict(
                type='ReLUConnector', in_channel=256, out_channel=1024)),
        loss_forward_mappings=dict(
            loss_f4=dict(
                s_feature=dict(
                    from_student=True, recorder='feat_4', record_idx=1),
                t_feature=dict(
                    from_student=False, recorder='feat_4', record_idx=2)),
            loss_f3=dict(
                s_feature=dict(
                    from_student=True, recorder='feat_3', record_idx=1),
                t_feature=dict(
                    from_student=False, recorder='feat_3', record_idx=2)),
            loss_kl=dict(
                preds_S=dict(from_student=True, recorder='fc'),
                preds_T=dict(from_student=False, recorder='fc')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
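For context, a config like this is normally consumed through MMEngine. The sketch below shows one way to load it and build the distillation algorithm; the config path is hypothetical, `register_all_modules` is assumed to live in `mmrazor.utils`, and both mmcls and mmrazor must be installed so that the `mmcls::` cross-repo references resolve.

```python
from mmengine.config import Config

from mmrazor.registry import MODELS
from mmrazor.utils import register_all_modules  # assumed helper

# Register mmrazor components so the registry can resolve the config types.
register_all_modules()

# Hypothetical location of the config shown above.
cfg = Config.fromfile(
    'configs/distill/mmcls/fitnet/'
    'fitnet_backbone_logits_resnet50_resnet18_8xb16_in1k.py')

# Builds `SingleTeacherDistill`: a ResNet-18 student, a frozen ResNet-50
# teacher, and the ConfigurableDistiller wiring recorders, connectors and
# losses together.
algorithm = MODELS.build(cfg.model)
print(type(algorithm).__name__)  # SingleTeacherDistill
```
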
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .general_connector import BNConnector, ReLUConnector, SingleConvConnector

__all__ = ['BNConnector', 'ReLUConnector', 'SingleConvConnector']
@@ -0,0 +1,37 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod

import torch
import torch.nn as nn


class BaseConnector(nn.Module, metaclass=ABCMeta):
    """Base class of connectors.

    Connector is mainly used for distill; it usually converts the channel
    number of the input feature to align the features of student and
    teacher.

    All subclasses should implement the following APIs:

    - ``forward_train()``
    """

> Review comment: distill -> distillation

    def __init__(self) -> None:
        super().__init__()

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        """Forward computation.

        Args:
            feature (torch.Tensor): Input feature.
        """
        return self.forward_train(feature)

    @abstractmethod
    def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
        """Abstract train computation.

        Args:
            feature (torch.Tensor): Input feature.
        """
        pass
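To make the contract concrete: a subclass only has to implement `forward_train()`, and `forward()` dispatches to it. The toy subclass below is for illustration only and is not part of this diff; it assumes `BaseConnector` from the file above is importable.

```python
import torch
import torch.nn as nn

# Assumes BaseConnector (defined in the file above) is in scope.


class ToyConnector(BaseConnector):
    """Illustrative connector that applies a single 1x1 convolution."""

    def __init__(self, in_channel: int, out_channel: int) -> None:
        super().__init__()
        self.proj = nn.Conv2d(in_channel, out_channel, kernel_size=1)

    def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
        return self.proj(feature)


connector = ToyConnector(in_channel=8, out_channel=16)
out = connector(torch.randn(1, 8, 4, 4))  # forward() calls forward_train()
print(out.shape)  # torch.Size([1, 16, 4, 4])
```
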
@@ -0,0 +1,110 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn

from mmrazor.registry import MODELS
from .base_connector import BaseConnector


@MODELS.register_module()
class SingleConvConnector(BaseConnector):
    """General connector which only contains a conv layer.

    Args:
        in_channel (int): The input channel of the connector.
        out_channel (int): The output channel of the connector.
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1)
        self.init_parameters()

> Review comment: Use …

    def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
        """Forward computation.

        Args:
            feature (torch.Tensor): Input feature.
        """
        return self.conv(feature)

    def init_parameters(self) -> None:
        """Init parameters."""
        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    device = m.weight.device
                    in_channels, _, k1, k2 = m.weight.shape
                    m.weight[:] = torch.randn(
                        m.weight.shape, device=device) / np.sqrt(
                            k1 * k2 * in_channels) * 1e-4
                    if hasattr(m, 'bias') and m.bias is not None:
                        nn.init.zeros_(m.bias)
                else:
                    continue


@MODELS.register_module()
class BNConnector(BaseConnector):
    """General connector which contains a conv layer with BN.

    Args:
        in_channel (int): The input channels of the connector.
        out_channel (int): The output channels of the connector.
    """

> Review comment: -> ConvBNConnector

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False)
        self.bn = nn.BatchNorm2d(out_channel)

> Review comment: Use …

    def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
        """Forward computation.

        Args:
            feature (torch.Tensor): Input feature.
        """
        return self.bn(self.conv(feature))


@MODELS.register_module()
class ReLUConnector(BaseConnector):
    """General connector which contains a conv layer with BN and ReLU.

    Args:
        in_channel (int): The input channels of the connector.
        out_channel (int): The output channels of the connector.
    """

> Review comment: -> ConvBNReLUConnector?

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU(inplace=True)

> Review comment: The same …

    def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
        """Forward computation.

        Args:
            feature (torch.Tensor): Input feature.
        """
        return self.relu(self.bn(self.conv(feature)))
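To see how these connectors line up with the config above, the sketch below builds a `ReLUConnector` through the registry and lifts a ResNet-18 `layer4`-sized feature to the teacher's 2048 channels. The feature tensor is fake, and the connector module is assumed to have been imported already so that registration has happened.

```python
import torch

from mmrazor.registry import MODELS

# Mirrors the `student_connectors.loss_f4` entry of the config above.
connector = MODELS.build(
    dict(type='ReLUConnector', in_channel=512, out_channel=2048))

# Fake ResNet-18 `layer4` output: 512 channels at 7x7 resolution.
student_feat = torch.randn(2, 512, 7, 7)

# After the connector, the student feature has the teacher's channel
# count, so the L2 hint loss against the ResNet-50 feature is well-defined.
aligned = connector(student_feat)
print(aligned.shape)  # torch.Size([2, 2048, 7, 7])
```
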
> Review comment: Wouldn't it be better to inherit from BaseModule? Then self.init_parameters() could be replaced with init_weights() and init_cfg.
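If the reviewer's suggestion were adopted, the connector could inherit from MMEngine's `BaseModule` and declare its initialization via `init_cfg`. A rough sketch of what that might look like; the `init_cfg` shown is only an approximation of the hand-written scaled-normal init, not an exact equivalent.

```python
import torch
import torch.nn as nn
from mmengine.model import BaseModule


class SingleConvConnector(BaseModule):
    """Sketch: 1x1 conv connector with weight init driven by ``init_cfg``."""

    def __init__(self, in_channel: int, out_channel: int,
                 init_cfg=None) -> None:
        # e.g. init_cfg=dict(type='Normal', std=1e-4, layer='Conv2d')
        super().__init__(init_cfg=init_cfg)
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=1)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        return self.conv(feature)


connector = SingleConvConnector(
    8, 16, init_cfg=dict(type='Normal', std=1e-4, layer='Conv2d'))
connector.init_weights()  # applies init_cfg instead of init_parameters()
print(connector(torch.randn(1, 8, 4, 4)).shape)
```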