-
Notifications
You must be signed in to change notification settings - Fork 231
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #207 from wilxy/dev-1.x
[Feature] Add connector components and FitNet
- Loading branch information
Showing
14 changed files
with
530 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# FitNets | ||
|
||
> [FitNets: Hints for Thin Deep Nets](https://arxiv.org/abs/1412.6550) | ||
<!-- [ALGORITHM] --> | ||
|
||
## Abstract | ||
|
||
While depth tends to improve network performances, it also makes gradient-based | ||
training more difficult since deeper networks tend to be more non-linear. The recently | ||
proposed knowledge distillation approach is aimed at obtaining small and fast-to-execute | ||
models, and it has shown that a student network could imitate the soft output of a larger | ||
teacher network or ensemble of networks. In this paper, we extend this idea to allow the | ||
training of a student that is deeper and thinner than the teacher, using not only the outputs | ||
but also the intermediate representations learned by the teacher as hints to improve the | ||
training process and final performance of the student. Because the student intermediate hidden | ||
layer will generally be smaller than the teacher's intermediate hidden layer, additional parameters | ||
are introduced to map the student hidden layer to the prediction of the teacher hidden layer. This | ||
allows one to train deeper students that can generalize better or run faster, a trade-off that is | ||
controlled by the chosen student capacity. For example, on CIFAR-10, a deep student network with | ||
almost 10.4 times less parameters outperforms a larger, state-of-the-art teacher network. | ||
|
||
![pipeline](/docs/en/imgs/model_zoo/fitnet/pipeline.png) | ||
|
||
## Results and models | ||
|
||
### Classification | ||
|
||
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | | ||
| :---------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :----------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------- | | ||
| backbone & logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.85 | 76.55 | 69.90 | [config](./fitnet_backbone_logits_resnet50_resnet18_8xb16_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](<>) \| [log](<>) | | ||
|
||
## Citation | ||
|
||
```latex | ||
@inproceedings{DBLP:journals/corr/RomeroBKCGB14, | ||
author = {Adriana Romero, Nicolas Ballas, Samira Ebrahimi Kahou, Antoine Chassang, Carlo Gatta and Yoshua Bengio}, | ||
editor = {Yoshua Bengio and Yann LeCun}, | ||
title = {FitNets: Hints for Thin Deep Nets}, | ||
booktitle = {3rd International Conference on Learning Representations, {ICLR} 2015, | ||
San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings}, | ||
year = {2015}, | ||
url = {http://arxiv.org/abs/1412.6550}, | ||
timestamp = {Thu, 25 Jul 2019 14:25:38 +0200}, | ||
biburl = {https://dblp.org/rec/journals/corr/RomeroBKCGB14.bib}, | ||
bibsource = {dblp computer science bibliography, https://dblp.org} | ||
} | ||
``` |
71 changes: 71 additions & 0 deletions
71
configs/distill/mmcls/fitnet/fitnet_backbone_logits_resnet50_resnet18_8xb32_in1k.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
_base_ = [ | ||
'mmcls::_base_/datasets/imagenet_bs32.py', | ||
'mmcls::_base_/schedules/imagenet_bs256.py', | ||
'mmcls::_base_/default_runtime.py' | ||
] | ||
|
||
model = dict( | ||
_scope_='mmrazor', | ||
type='SingleTeacherDistill', | ||
data_preprocessor=dict( | ||
type='ImgDataPreprocessor', | ||
# RGB format normalization parameters | ||
mean=[123.675, 116.28, 103.53], | ||
std=[58.395, 57.12, 57.375], | ||
# convert image from BGR to RGB | ||
bgr_to_rgb=True), | ||
architecture=dict( | ||
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False), | ||
teacher=dict( | ||
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True), | ||
teacher_ckpt='resnet50_8xb32_in1k_20210831-ea4938fc.pth', | ||
distiller=dict( | ||
type='ConfigurableDistiller', | ||
student_recorders=dict( | ||
bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.1.relu'), | ||
bb_s3=dict(type='ModuleOutputs', source='backbone.layer3.1.relu'), | ||
fc=dict(type='ModuleOutputs', source='head.fc')), | ||
teacher_recorders=dict( | ||
bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.2.relu'), | ||
bb_s3=dict(type='ModuleOutputs', source='backbone.layer3.5.relu'), | ||
fc=dict(type='ModuleOutputs', source='head.fc')), | ||
distill_losses=dict( | ||
loss_s4=dict(type='L2Loss', loss_weight=10), | ||
loss_s3=dict(type='L2Loss', loss_weight=10), | ||
loss_kl=dict( | ||
type='KLDivergence', tau=6, loss_weight=10, reduction='mean')), | ||
connectors=dict( | ||
loss_s4_sfeat=dict( | ||
type='ConvBNReLUConnector', | ||
in_channel=512, | ||
out_channel=2048, | ||
norm_cfg=dict(type='BN')), | ||
loss_s3_sfeat=dict( | ||
type='ConvBNReLUConnector', | ||
in_channel=256, | ||
out_channel=1024, | ||
norm_cfg=dict(type='BN'))), | ||
loss_forward_mappings=dict( | ||
loss_s4=dict( | ||
s_feature=dict( | ||
from_student=True, | ||
recorder='bb_s4', | ||
record_idx=1, | ||
connector='loss_s4_sfeat'), | ||
t_feature=dict( | ||
from_student=False, recorder='bb_s4', record_idx=2)), | ||
loss_s3=dict( | ||
s_feature=dict( | ||
from_student=True, | ||
recorder='bb_s3', | ||
record_idx=1, | ||
connector='loss_s3_sfeat'), | ||
t_feature=dict( | ||
from_student=False, recorder='bb_s3', record_idx=2)), | ||
loss_kl=dict( | ||
preds_S=dict(from_student=True, recorder='fc'), | ||
preds_T=dict(from_student=False, recorder='fc'))))) | ||
|
||
find_unused_parameters = True | ||
|
||
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop') |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .backbones import * # noqa: F401,F403 | ||
from .connectors import * # noqa: F401,F403 | ||
from .dynamic_op import * # noqa: F401,F403 | ||
from .heads import * # noqa: F401,F403 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .general_connector import (ConvBNConnector, ConvBNReLUConnector, | ||
SingleConvConnector) | ||
|
||
__all__ = ['ConvBNConnector', 'ConvBNReLUConnector', 'SingleConvConnector'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from abc import ABCMeta, abstractmethod | ||
from typing import Dict, Optional | ||
|
||
import torch | ||
from mmcv.runner import BaseModule | ||
|
||
|
||
class BaseConnector(BaseModule, metaclass=ABCMeta): | ||
"""Base class of connectors. | ||
Connector is mainly used for distillation, it usually converts the channel | ||
number of input feature to align features of student and teacher. | ||
All subclasses should implement the following APIs: | ||
- ``forward_train()`` | ||
Args: | ||
init_cfg (dict, optional): The config to control the initialization. | ||
""" | ||
|
||
def __init__(self, init_cfg: Optional[Dict] = None) -> None: | ||
super().__init__(init_cfg=init_cfg) | ||
|
||
def forward(self, feature: torch.Tensor) -> None: | ||
"""Forward computation. | ||
Args: | ||
feature (torch.Tensor): Input feature. | ||
""" | ||
return self.forward_train(feature) | ||
|
||
@abstractmethod | ||
def forward_train(self, feature) -> torch.Tensor: | ||
"""Abstract train computation. | ||
Args: | ||
feature (torch.Tensor): Input feature. | ||
""" | ||
pass |
Oops, something went wrong.