[Feature] Add connector components and FitNet #207

Merged · 6 commits · Jul 28, 2022
48 changes: 48 additions & 0 deletions configs/distill/mmcls/fitnet/README.md
@@ -0,0 +1,48 @@
# FitNets

> [FitNets: Hints for Thin Deep Nets](https://arxiv.org/abs/1412.6550)

<!-- [ALGORITHM] -->

## Abstract

While depth tends to improve network performances, it also makes gradient-based
training more difficult since deeper networks tend to be more non-linear. The recently
proposed knowledge distillation approach is aimed at obtaining small and fast-to-execute
models, and it has shown that a student network could imitate the soft output of a larger
teacher network or ensemble of networks. In this paper, we extend this idea to allow the
training of a student that is deeper and thinner than the teacher, using not only the outputs
but also the intermediate representations learned by the teacher as hints to improve the
training process and final performance of the student. Because the student intermediate hidden
layer will generally be smaller than the teacher's intermediate hidden layer, additional parameters
are introduced to map the student hidden layer to the prediction of the teacher hidden layer. This
allows one to train deeper students that can generalize better or run faster, a trade-off that is
controlled by the chosen student capacity. For example, on CIFAR-10, a deep student network with
almost 10.4 times less parameters outperforms a larger, state-of-the-art teacher network.

![pipeline](/docs/en/imgs/model_zoo/fitnet/pipeline.png)
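
The hint objective is easy to state in code. Below is a minimal PyTorch sketch of the idea, assuming matching spatial sizes between the chosen student and teacher stages: a small regressor (here a 1x1 convolution) lifts the thinner student feature to the teacher's channel count, and an L2 loss pulls the regressed feature toward the teacher's. All names are illustrative; this is not the MMRazor implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class HintRegressor(nn.Module):
    """Maps a student feature map to the teacher's channel dimension."""

    def __init__(self, student_channels: int, teacher_channels: int):
        super().__init__()
        # A 1x1 conv is the simplest regressor when spatial sizes already match.
        self.regressor = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, feat_s: torch.Tensor) -> torch.Tensor:
        return self.regressor(feat_s)


def hint_loss(feat_s: torch.Tensor, feat_t: torch.Tensor,
              regressor: HintRegressor) -> torch.Tensor:
    """L2 distance between the regressed student feature and the teacher feature."""
    return F.mse_loss(regressor(feat_s), feat_t)


# Toy example: ResNet-18 stage-4 output (512 ch) vs ResNet-50 stage-4 output (2048 ch).
regressor = HintRegressor(512, 2048)
loss = hint_loss(torch.randn(2, 512, 7, 7), torch.randn(2, 2048, 7, 7), regressor)
```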

## Results and models

### Classification

| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :---------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :----------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------- |
| backbone & logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.85 | 76.55 | 69.90 | [config](./fitnet_backbone_logits_resnet50_resnet18_8xb16_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](<>) \| [log](<>) |

## Citation

```latex
@inproceedings{DBLP:journals/corr/RomeroBKCGB14,
  author    = {Adriana Romero and Nicolas Ballas and Samira Ebrahimi Kahou and
               Antoine Chassang and Carlo Gatta and Yoshua Bengio},
  editor    = {Yoshua Bengio and Yann LeCun},
  title     = {FitNets: Hints for Thin Deep Nets},
  booktitle = {3rd International Conference on Learning Representations, {ICLR} 2015,
               San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings},
  year      = {2015},
  url       = {http://arxiv.org/abs/1412.6550},
  timestamp = {Thu, 25 Jul 2019 14:25:38 +0200},
  biburl    = {https://dblp.org/rec/journals/corr/RomeroBKCGB14.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
@@ -0,0 +1,71 @@
_base_ = [
    'mmcls::_base_/datasets/imagenet_bs32.py',
    'mmcls::_base_/schedules/imagenet_bs256.py',
    'mmcls::_base_/default_runtime.py'
]

model = dict(
    _scope_='mmrazor',
    type='SingleTeacherDistill',
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # RGB format normalization parameters
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        # convert image from BGR to RGB
        bgr_to_rgb=True),
    architecture=dict(
        cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
    teacher=dict(
        cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True),
    teacher_ckpt='resnet50_8xb32_in1k_20210831-ea4938fc.pth',
    distiller=dict(
        type='ConfigurableDistiller',
        student_recorders=dict(
            bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.1.relu'),
            bb_s3=dict(type='ModuleOutputs', source='backbone.layer3.1.relu'),
            fc=dict(type='ModuleOutputs', source='head.fc')),
        teacher_recorders=dict(
            bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.2.relu'),
            bb_s3=dict(type='ModuleOutputs', source='backbone.layer3.5.relu'),
            fc=dict(type='ModuleOutputs', source='head.fc')),
        distill_losses=dict(
            loss_s4=dict(type='L2Loss', loss_weight=10),
            loss_s3=dict(type='L2Loss', loss_weight=10),
            loss_kl=dict(
                type='KLDivergence', tau=6, loss_weight=10, reduction='mean')),
        connectors=dict(
            loss_s4_sfeat=dict(
                type='ConvBNReLUConnector',
                in_channel=512,
                out_channel=2048,
                norm_cfg=dict(type='BN')),
            loss_s3_sfeat=dict(
                type='ConvBNReLUConnector',
                in_channel=256,
                out_channel=1024,
                norm_cfg=dict(type='BN'))),
        loss_forward_mappings=dict(
            loss_s4=dict(
                s_feature=dict(
                    from_student=True,
                    recorder='bb_s4',
                    record_idx=1,
                    connector='loss_s4_sfeat'),
                t_feature=dict(
                    from_student=False, recorder='bb_s4', record_idx=2)),
            loss_s3=dict(
                s_feature=dict(
                    from_student=True,
                    recorder='bb_s3',
                    record_idx=1,
                    connector='loss_s3_sfeat'),
                t_feature=dict(
                    from_student=False, recorder='bb_s3', record_idx=2)),
            loss_kl=dict(
                preds_S=dict(from_student=True, recorder='fc'),
                preds_T=dict(from_student=False, recorder='fc')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
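
For readers new to `ConfigurableDistiller`, the `loss_forward_mappings` block above simply routes recorded activations, optionally through a connector, into the registered losses. The sketch below spells out the computation described by the `loss_s4` entry in plain PyTorch; the tensors are random stand-ins for what the recorders would capture, `nn.Sequential(conv, BN, ReLU)` is a stand-in for `ConvBNReLUConnector`, and `F.mse_loss` stands in for the registered `L2Loss`, so it illustrates the data flow rather than the distiller's actual code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Random stand-ins for what the 'bb_s4' recorders would capture on one batch.
student_bb_s4 = torch.randn(16, 512, 7, 7)    # ResNet-18 stage-4 feature
teacher_bb_s4 = torch.randn(16, 2048, 7, 7)   # ResNet-50 stage-4 feature

# Stand-in for the 'loss_s4_sfeat' connector: lift 512 -> 2048 channels.
connector = nn.Sequential(
    nn.Conv2d(512, 2048, kernel_size=1),
    nn.BatchNorm2d(2048),
    nn.ReLU(inplace=True))

# loss_s4: L2 between the connected student feature and the raw teacher feature,
# scaled by loss_weight=10 as in the config above.
loss_s4 = 10 * F.mse_loss(connector(student_bb_s4), teacher_bb_s4)
```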
Binary file added docs/en/imgs/model_zoo/fitnet/pipeline.png
@@ -5,6 +5,7 @@
from mmcv.runner import load_checkpoint
from mmengine import BaseDataElement
from mmengine.model import BaseModel
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.models.utils import add_prefix
@@ -18,6 +19,7 @@ class SingleTeacherDistill(BaseAlgorithm):
    only use one teacher.

    Args:
        distiller (dict): The config dict for built distiller.
        teacher (dict | BaseModel): The config dict for teacher model or built
            teacher model.
        teacher_ckpt (str): The path of teacher's checkpoint. Defaults to None.
@@ -26,6 +28,10 @@ class SingleTeacherDistill(BaseAlgorithm):
        teacher_norm_eval (bool): Whether to set teacher's norm layers to eval
            mode, namely, freeze running stats (mean and var). Note: Effect on
            Batch Norm and its variants only. Defaults to True.
        student_trainable (bool): Whether the student is trainable. Defaults
            to True.
        calculate_student_loss (bool): Whether to calculate student loss
            (original task loss) to update student model. Defaults to True.
    """

    def __init__(self,
@@ -34,7 +40,9 @@ def __init__(self,
                 teacher_ckpt: Optional[str] = None,
                 teacher_trainable: bool = False,
                 teacher_norm_eval: bool = True,
                 **kwargs):
                 student_trainable: bool = True,
                 calculate_student_loss: bool = True,
                 **kwargs) -> None:
        super().__init__(**kwargs)

        self.distiller = MODELS.build(distiller)
@@ -55,13 +63,21 @@ def __init__(self,
        self.teacher_trainable = teacher_trainable
        self.teacher_norm_eval = teacher_norm_eval

        # The student model will not calculate gradients and update parameters
        # in some pretraining process.
        self.student_trainable = student_trainable

        # The student loss will not be updated into ``losses`` in some
        # pretraining process.
        self.calculate_student_loss = calculate_student_loss

        # In ``ConfigurableDistiller``, the recorder manager is just
        # constructed, but not really initialized yet.
        self.distiller.prepare_from_student(self.student)
        self.distiller.prepare_from_teacher(self.teacher)

    @property
    def student(self):
    def student(self) -> nn.Module:
        """Alias for ``architecture``."""
        return self.architecture

@@ -86,16 +102,25 @@ def loss(
        else:
            with self.distiller.teacher_recorders, self.distiller.deliveries:
                with torch.no_grad():

                    _ = self.teacher(batch_inputs, data_samples, mode='loss')

        # If the `override_data` of a delivery is True, the delivery will
        # override the origin data with the recorded data.
        self.distiller.set_deliveries_override(True)
        with self.distiller.student_recorders, self.distiller.deliveries:
            student_losses = self.student(
                batch_inputs, data_samples, mode='loss')
        losses.update(add_prefix(student_losses, 'student'))
        # Original task loss will not be used during some pretraining process.
        if self.calculate_student_loss:
            with self.distiller.student_recorders, self.distiller.deliveries:
                student_losses = self.student(
                    batch_inputs, data_samples, mode='loss')
            losses.update(add_prefix(student_losses, 'student'))
        else:
            with self.distiller.student_recorders, self.distiller.deliveries:
                if self.student_trainable:
                    _ = self.student(batch_inputs, data_samples, mode='loss')
                else:
                    with torch.no_grad():
                        _ = self.student(
                            batch_inputs, data_samples, mode='loss')

        # Automatically compute distill losses based on `loss_forward_mappings`
        # The required data already exists in the recorders.
@@ -104,7 +129,7 @@ def loss(

        return losses

    def train(self, mode=True):
    def train(self, mode: bool = True) -> None:
        """Set distiller's forward mode."""
        super().train(mode)
        if mode and self.teacher_norm_eval:
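
The two new flags are mainly intended for multi-stage recipes in which a student is first pretrained purely from the teacher's hints and only later trained with its own task loss. As a hedged illustration (the inherited file name is assumed from the README link above, and mmengine's default dict merging is relied on), such a pretraining stage could be expressed as:

```python
# Hypothetical pretraining-stage config: keep the student trainable but skip its
# own classification loss, so only the distillation losses drive the update.
_base_ = ['./fitnet_backbone_logits_resnet50_resnet18_8xb16_in1k.py']  # assumed file name

model = dict(
    student_trainable=True,
    calculate_student_loss=False)
```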
1 change: 1 addition & 0 deletions mmrazor/models/architectures/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import * # noqa: F401,F403
from .connectors import * # noqa: F401,F403
from .dynamic_op import * # noqa: F401,F403
from .heads import * # noqa: F401,F403
5 changes: 5 additions & 0 deletions mmrazor/models/architectures/connectors/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .general_connector import (ConvBNConnector, ConvBNReLUConnector,
                                SingleConvConnector)

__all__ = ['ConvBNConnector', 'ConvBNReLUConnector', 'SingleConvConnector']
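
As a quick sanity check of the exported names, a connector can be built from its config dict through the registry, much as the distiller does internally. This is a sketch: the registry import path and the constructor arguments are assumptions based on the FitNet config earlier in this PR.

```python
import torch

from mmrazor.registry import MODELS  # assumed registry location

# Same dict as the 'loss_s4_sfeat' entry in the FitNet config above.
connector_cfg = dict(
    type='ConvBNReLUConnector',
    in_channel=512,
    out_channel=2048,
    norm_cfg=dict(type='BN'))

connector = MODELS.build(connector_cfg)
out = connector(torch.randn(2, 512, 7, 7))  # expected shape: (2, 2048, 7, 7)
```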
41 changes: 41 additions & 0 deletions mmrazor/models/architectures/connectors/base_connector.py
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Dict, Optional

import torch
from mmcv.runner import BaseModule


class BaseConnector(BaseModule, metaclass=ABCMeta):
    """Base class of connectors.

    Connectors are mainly used for distillation; a connector usually converts
    the channel number of the input feature to align the features of student
    and teacher.

    All subclasses should implement the following APIs:

    - ``forward_train()``

    Args:
        init_cfg (dict, optional): The config to control the initialization.
    """

    def __init__(self, init_cfg: Optional[Dict] = None) -> None:
        super().__init__(init_cfg=init_cfg)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        """Forward computation.

        Args:
            feature (torch.Tensor): Input feature.
        """
        return self.forward_train(feature)

    @abstractmethod
    def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
        """Abstract train computation.

        Args:
            feature (torch.Tensor): Input feature.
        """
        pass
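
To illustrate the contract, a minimal custom connector only has to implement `forward_train()`; the `forward()` inherited from the base class dispatches to it. The class below is hypothetical and not part of this PR; the import path follows the file added above, and the registry decorator is assumed to behave as in other OpenMMLab components.

```python
import torch

from mmrazor.models.architectures.connectors.base_connector import BaseConnector
from mmrazor.registry import MODELS  # assumed registry location


@MODELS.register_module()
class IdentityConnector(BaseConnector):
    """A do-nothing connector, usable when student and teacher channels already match."""

    def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
        return feature
```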