News! I'm going to release a major update of this repo. The new version will contain most of the methods in the Todo list. Please stay tuned.
PyTorch implementation of various Knowledge Distillation (KD) methods.
This repository is a simple reference that mainly focuses on basic knowledge distillation/transfer methods. Thus many tricks and variations, such as step-by-step training, iterative training, ensembles of teachers, ensembles of KD methods, data-free, self-distillation and online distillation, are not considered. I hope it is useful for your project or research.
I will update this repo regularly with new KD methods. If I have missed some basic methods, please contact me.
Name | Method | Paper Link | Code Link |
---|---|---|---|
Baseline | basic model with softmax loss | — | code |
Logits | mimic learning via regressing logits | paper | code |
ST | soft target | paper | code |
AT | attention transfer | paper | code |
Fitnet | hints for thin deep nets | paper | code |
NST | neural selective transfer | paper | code |
PKT | probabilistic knowledge transfer | paper | code |
FSP | flow of solution procedure | paper | code |
FT | factor transfer | paper | code |
RKD | relational knowledge distillation | paper | code |
AB | activation boundary | paper | code |
SP | similarity preservation | paper | code |
Sobolev | sobolev/jacobian matching | paper | code |
BSS | boundary supporting samples | paper | code |
CC | correlation congruence | paper | code |
LwM | learning without memorizing | paper | code |
IRG | instance relationship graph | paper | code |
VID | variational information distillation | paper | code |
OFD | overhaul of feature distillation | paper | code |
AFD | attention feature distillation | paper | code |
CRD | contrastive representation distillation | paper | code |
DML | deep mutual learning | paper | code |
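
For the two logit-based entries in the table above (`Logits` and `ST`), a minimal PyTorch sketch of the losses might look as follows. It assumes the standard formulations (MSE regression on logits; KL divergence between temperature-softened outputs, scaled by T²). The temperature `T=4.0` and the function names are hypothetical and not taken from this repo's code.

```python
import torch
import torch.nn.functional as F


def logits_loss(logits_s, logits_t):
    # Logits: regress the teacher's pre-softmax outputs with an L2 (MSE) loss
    return F.mse_loss(logits_s, logits_t)


def st_loss(logits_s, logits_t, T=4.0):
    # ST: KL(p_t || p_s) between temperature-softened distributions,
    # scaled by T^2 so gradient magnitudes stay comparable across temperatures
    log_p_s = F.log_softmax(logits_s / T, dim=1)
    p_t = F.softmax(logits_t / T, dim=1)
    return F.kl_div(log_p_s, p_t, reduction="batchmean") * T * T


if __name__ == "__main__":
    torch.manual_seed(0)
    logits_s, logits_t = torch.randn(8, 10), torch.randn(8, 10)
    labels = torch.randint(0, 10, (8,))
    # The KD term is added to the usual cross-entropy with a trade-off weight
    total = F.cross_entropy(logits_s, labels) + 1.0 * st_loss(logits_s, logits_t)
    print(total.item())
```

In actual training the teacher logits would be computed under `torch.no_grad()`, so only the student receives gradients.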

- Note, there are some differences between this repository and the original papers:
  - For `AT`: I use the sum of absolute values with power p=2 as the attention.
  - For `Fitnet`: The training procedure is single-stage, without the hint layer.
  - For `NST`: I employ a polynomial kernel with d=2 and c=0.
  - For `AB`: Two-stage training: the first 50 epochs are used for initialization, and the second stage employs only CE without ST.
  - For `BSS`: 80% of the epochs employ the CE+BSS loss, and the remaining 20% use only CE. In addition, warmup is used for the first 10 epochs.
  - For `CC`: For consistency, I only consider CC without instance congruence. A Gaussian RBF kernel is employed because the Bilinear Pool kernel is similar to PKT. I choose the P=2 order Taylor expansion of the Gaussian RBF kernel. No special sampling strategy is used.
  - For `LwM`: I employ it after rb2 (the middle conv layer) rather than rb3 (the last conv layer), because the base net is a ResNet that ends with GAP followed by a classifier. If applied after rb3, the gradients used by Grad-CAM have the same values across H and W in each channel.
  - For `IRG`: I only use the one-to-one mode.
  - For `VID`: I set the hidden channel size to be the same as the output channel size and remove BN in μ.
  - For `AFD`: I find the original implementation of attention unstable, so I replace it with an SE block.
  - For `DML`: Only two nets are employed. Both nets are updated synchronously to avoid multiple forward passes.
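
The `AT`, `NST` and `CC` notes above fix specific attention and kernel choices. Below is a minimal sketch of those three losses under those choices, assuming student and teacher feature maps (or embeddings) of the same shape. The `AT` loss is written as a squared-difference variant of the original, `gamma=0.4` for `CC` is a hypothetical value, and the exact normalization may differ from the code in this repo.

```python
import math

import torch
import torch.nn.functional as F


def attention_map(fm, p=2):
    # AT: sum of |A_c|^p over channels, flattened to (B, H*W), L2-normalized per sample
    am = fm.abs().pow(p).sum(dim=1).flatten(1)
    return F.normalize(am, p=2, dim=1)


def at_loss(fm_s, fm_t, p=2):
    return (attention_map(fm_s, p) - attention_map(fm_t, p)).pow(2).mean()


def poly_kernel(x, y, d=2, c=0.0):
    # x: (B, Cx, HW), y: (B, Cy, HW) -> (B, Cx, Cy), entries (x_i . y_j + c)^d
    return (torch.bmm(x, y.transpose(1, 2)) + c).pow(d)


def nst_loss(fm_s, fm_t, d=2, c=0.0):
    # NST: each channel is a sample over spatial positions; match the two
    # channel distributions with a (biased) squared MMD under a polynomial kernel
    s = F.normalize(fm_s.flatten(2), dim=2)
    t = F.normalize(fm_t.flatten(2), dim=2)
    return (poly_kernel(s, s, d, c).mean() + poly_kernel(t, t, d, c).mean()
            - 2 * poly_kernel(s, t, d, c).mean())


def cc_correlation(emb, gamma=0.4, P=2):
    # CC: P-order Taylor expansion of exp(-gamma * ||x - y||^2) for unit vectors,
    # i.e. exp(-2*gamma) * sum_k (2*gamma * x.y)^k / k!
    emb = F.normalize(emb, dim=1)
    sim = emb @ emb.t()
    corr = torch.zeros_like(sim)
    for k in range(P + 1):
        corr = corr + math.exp(-2 * gamma) * (2 * gamma) ** k / math.factorial(k) * sim.pow(k)
    return corr


def cc_loss(emb_s, emb_t, gamma=0.4, P=2):
    # Match the batch-level correlation matrices of student and teacher
    return F.mse_loss(cc_correlation(emb_s, gamma, P), cc_correlation(emb_t, gamma, P))
```

`at_loss` and `nst_loss` take feature maps of shape (B, C, H, W); `cc_loss` takes embeddings of shape (B, D), so it depends on which samples share a batch.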

Datasets:
- CIFAR10
- CIFAR100

Networks:
- Resnet-20
- Resnet-110

The networks are the same as those in Table 6 of the ResNet paper.
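
A compact sketch of this CIFAR ResNet family: depth is 6n+2 (n=3 for ResNet-20, n=18 for ResNet-110), with three stages of 16, 32 and 64 filters. It assumes parameter-free shortcuts (spatial subsampling plus zero-padding of channels, "option A" in the ResNet paper); the networks in this repo may differ in shortcut type, layer naming or the outputs they return.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Two 3x3 convs with a parameter-free (option A) shortcut."""

    def __init__(self, in_c, out_c, stride):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.stride = stride
        self.extra_c = out_c - in_c

    def shortcut(self, x):
        # Subsample spatially and zero-pad the extra channels instead of using a 1x1 conv
        if self.stride != 1:
            x = x[:, :, ::self.stride, ::self.stride]
        if self.extra_c > 0:
            x = F.pad(x, (0, 0, 0, 0, 0, self.extra_c))
        return x

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.shortcut(x))


class CifarResNet(nn.Module):
    def __init__(self, depth=20, num_classes=10):
        super().__init__()
        assert (depth - 2) % 6 == 0, "depth must be 6n+2"
        n = (depth - 2) // 6
        self.conv = nn.Conv2d(3, 16, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(16)
        blocks, in_c = [], 16
        # Three stages; the first block of stages 2 and 3 halves H and W
        for out_c, stride in [(16, 1), (32, 2), (64, 2)]:
            for i in range(n):
                blocks.append(BasicBlock(in_c, out_c, stride if i == 0 else 1))
                in_c = out_c
        self.blocks = nn.Sequential(*blocks)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        x = F.relu(self.bn(self.conv(x)))
        x = self.blocks(x)
        x = x.mean(dim=(2, 3))  # global average pooling
        return self.fc(x)


def cifar_resnet(depth, num_classes=10):
    return CifarResNet(depth, num_classes)
```

The depth counts weighted layers: one stem conv, 2 convs per block times 3n blocks, and the final linear layer.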

Training:
- Create a `./dataset` directory and download CIFAR10/CIFAR100 into it.
- Use the script `example_train_script.sh` to train various KD methods. You can simply specify the hyper-parameters listed in `train_xxx.py` or change them manually.
- The hyper-parameters I used can be found in the training logs (code: ezed).
- Some Notes:
- Sobolev and LwM are unstable when used alone; they may work better in conjunction with other KD methods.
- BSS may occasionally disrupt the training procedure, leading to poor results.
- Unless otherwise specified in the original papers, methods that could be applied to middle or multiple feature maps are only employed after the last conv layer. Extending them to multiple feature maps is simple.
- I assume the size (C, H, W) of the features is the same for teacher and student. If not, you could employ a 1×1 conv, a linear layer or pooling to rectify them.
- The trained baseline models are used as teachers. For a fair comparison, all student nets have the same initialization as the baseline models.
- The initial models, trained models and training logs are uploaded here (code: ezed).
- The trade-off parameter `--lambda_kd` and other hyper-parameters are not chosen carefully. Thus the following results do not reflect which method is better than the others.
- Some relation-based methods, e.g. PKT, RKD and CC, are less effective on CIFAR100. This may be because a batch contains more distinct classes but fewer samples per class. You could increase the batch size, create a memory bank or design advanced batch sampling methods.
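
The notes on mismatched (C, H, W) and on the `--lambda_kd` trade-off can be illustrated together. Below is a minimal sketch of one training step, assuming models that return `(feature_map, logits)`. `FeatureAdapter`, `TinyNet` and `kd_step` are hypothetical names, and the KD term is a Fitnet-style MSE chosen only for illustration. The adapter uses a 1×1 conv to match channels and adaptive average pooling to match the spatial size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureAdapter(nn.Module):
    # Hypothetical connector: 1x1 conv + BN to match the teacher's channel count,
    # then adaptive average pooling to match the teacher's (H, W)
    def __init__(self, c_s, c_t, out_hw):
        super().__init__()
        self.proj = nn.Sequential(nn.Conv2d(c_s, c_t, kernel_size=1, bias=False),
                                  nn.BatchNorm2d(c_t))
        self.pool = nn.AdaptiveAvgPool2d(out_hw)

    def forward(self, x):
        return self.pool(self.proj(x))


class TinyNet(nn.Module):
    # Hypothetical toy model returning (feature map, logits)
    def __init__(self, width, stride, num_classes=10):
        super().__init__()
        self.body = nn.Sequential(nn.Conv2d(3, width, 3, stride=stride, padding=1), nn.ReLU())
        self.fc = nn.Linear(width, num_classes)

    def forward(self, x):
        feat = self.body(x)
        return feat, self.fc(feat.mean(dim=(2, 3)))


def kd_step(student, teacher, adapter, optimizer, images, labels, lambda_kd=1.0):
    teacher.eval()
    with torch.no_grad():  # the teacher is frozen
        feat_t, _ = teacher(images)
    feat_s, logits_s = student(images)
    loss_ce = F.cross_entropy(logits_s, labels)
    loss_kd = F.mse_loss(adapter(feat_s), feat_t)  # Fitnet-style feature regression
    loss = loss_ce + lambda_kd * loss_kd  # lambda_kd weighs KD against CE
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    student, teacher = TinyNet(8, stride=1), TinyNet(16, stride=2)
    adapter = FeatureAdapter(8, 16, out_hw=(4, 4))  # student 8x8x8 -> teacher 16x4x4
    params = list(student.parameters()) + list(adapter.parameters())
    optimizer = torch.optim.SGD(params, lr=0.1, momentum=0.9)
    images, labels = torch.randn(4, 3, 8, 8), torch.randint(0, 10, (4,))
    print(kd_step(student, teacher, adapter, optimizer, images, labels, lambda_kd=0.5))
```

Only the student and the adapter are passed to the optimizer, so the teacher stays fixed while the adapter is trained jointly with the student.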

Teacher | Student | Name | CIFAR10 | CIFAR100 |
---|---|---|---|---|
None | resnet-20 | Baseline | 92.37% | 68.92% |
resnet-20 | resnet-20 | Logits | 93.30% | 70.36% |
resnet-20 | resnet-20 | ST | 93.12% | 70.27% |
resnet-20 | resnet-20 | AT | 92.89% | 69.70% |
resnet-20 | resnet-20 | Fitnet | 92.73% | 70.08% |
resnet-20 | resnet-20 | NST | 92.79% | 69.21% |
resnet-20 | resnet-20 | PKT | 92.50% | 69.25% |
resnet-20 | resnet-20 | FSP | 92.76% | 69.61% |
resnet-20 | resnet-20 | FT | 92.98% | 69.90% |
resnet-20 | resnet-20 | RKD | 92.72% | 69.48% |
resnet-20 | resnet-20 | AB | 93.04% | 69.96% |
resnet-20 | resnet-20 | SP | 92.88% | 69.85% |
resnet-20 | resnet-20 | Sobolev | 92.78% | 69.39% |
resnet-20 | resnet-20 | BSS | 92.58% | 69.96% |
resnet-20 | resnet-20 | CC | 93.01% | 69.27% |
resnet-20 | resnet-20 | LwM | 92.80% | 69.23% |
resnet-20 | resnet-20 | IRG | 92.77% | 70.37% |
resnet-20 | resnet-20 | VID | 92.61% | 69.39% |
resnet-20 | resnet-20 | OFD | 92.82% | 69.93% |
resnet-20 | resnet-20 | AFD | 92.56% | 69.63% |
resnet-20 | resnet-20 | CRD | 92.96% | 70.33% |

Teacher | Student | Name | CIFAR10 | CIFAR100 |
---|---|---|---|---|
None | resnet-20 | Baseline | 92.37% | 68.92% |
None | resnet-110 | Baseline | 93.86% | 73.15% |
resnet-110 | resnet-20 | Logits | 92.98% | 69.78% |
resnet-110 | resnet-20 | ST | 92.82% | 70.06% |
resnet-110 | resnet-20 | AT | 93.21% | 69.28% |
resnet-110 | resnet-20 | Fitnet | 93.04% | 69.81% |
resnet-110 | resnet-20 | NST | 92.83% | 69.31% |
resnet-110 | resnet-20 | PKT | 93.01% | 69.31% |
resnet-110 | resnet-20 | FSP | 92.78% | 69.78% |
resnet-110 | resnet-20 | FT | 93.01% | 69.49% |
resnet-110 | resnet-20 | RKD | 93.21% | 69.36% |
resnet-110 | resnet-20 | AB | 92.96% | 69.41% |
resnet-110 | resnet-20 | SP | 93.30% | 69.45% |
resnet-110 | resnet-20 | Sobolev | 92.60% | 69.23% |
resnet-110 | resnet-20 | BSS | 92.78% | 69.71% |
resnet-110 | resnet-20 | CC | 92.98% | 69.33% |
resnet-110 | resnet-20 | LwM | 92.52% | 69.11% |
resnet-110 | resnet-20 | IRG | 93.13% | 69.36% |
resnet-110 | resnet-20 | VID | 92.98% | 69.49% |
resnet-110 | resnet-20 | OFD | 93.13% | 69.81% |
resnet-110 | resnet-20 | AFD | 92.92% | 69.60% |
resnet-110 | resnet-20 | CRD | 92.92% | 70.80% |

Teacher | Student | Name | CIFAR10 | CIFAR100 |
---|---|---|---|---|
None | resnet-110 | Baseline | 93.86% | 73.15% |
resnet-110 | resnet-110 | Logits | 94.38% | 74.89% |
resnet-110 | resnet-110 | ST | 94.59% | 74.33% |
resnet-110 | resnet-110 | AT | 94.42% | 74.64% |
resnet-110 | resnet-110 | Fitnet | 94.43% | 73.63% |
resnet-110 | resnet-110 | NST | 94.43% | 73.55% |
resnet-110 | resnet-110 | PKT | 94.35% | 73.74% |
resnet-110 | resnet-110 | FSP | 94.39% | 73.59% |
resnet-110 | resnet-110 | FT | 94.30% | 74.72% |
resnet-110 | resnet-110 | RKD | 94.39% | 73.78% |
resnet-110 | resnet-110 | AB | 94.63% | 73.91% |
resnet-110 | resnet-110 | SP | 94.45% | 74.07% |
resnet-110 | resnet-110 | Sobolev | 94.26% | 73.14% |
resnet-110 | resnet-110 | BSS | 94.19% | 73.87% |
resnet-110 | resnet-110 | CC | 94.49% | 74.43% |
resnet-110 | resnet-110 | LwM | 94.19% | 73.28% |
resnet-110 | resnet-110 | IRG | 94.44% | 74.96% |
resnet-110 | resnet-110 | VID | 94.25% | 73.63% |
resnet-110 | resnet-110 | OFD | 94.38% | 74.11% |
resnet-110 | resnet-110 | AFD | 94.44% | 73.90% |
resnet-110 | resnet-110 | CRD | 94.30% | 75.44% |

Net1 | Net2 | Name | CIFAR10 (Net1/Net2) | CIFAR100 (Net1/Net2) |
---|---|---|---|---|
None | resnet-20 | Baseline | 92.37% | 68.92% |
None | resnet-110 | Baseline | 93.86% | 73.15% |
resnet-20 | resnet-20 | DML | 93.07%/93.37% | 70.39%/70.22% |
resnet-110 | resnet-20 | DML | 94.45%/92.92% | 74.53%/70.29% |
resnet-110 | resnet-110 | DML | 94.74%/94.79% | 74.72%/75.55% |
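
The `DML` note earlier says only two nets are used and both are updated synchronously to avoid multiple forward passes. A minimal sketch of such a step, assuming each net returns logits and the mimicry term is KL divergence at temperature `T=1` (a hypothetical default): each net forwards once, its peer's output is detached, and a single backward pass updates both nets.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def kl_to_peer(logits, peer_logits, T=1.0):
    # KL(p_peer || p_self); the peer is detached so each loss only updates its own net
    return F.kl_div(F.log_softmax(logits / T, dim=1),
                    F.softmax(peer_logits.detach() / T, dim=1),
                    reduction="batchmean") * T * T


def dml_step(net1, net2, opt1, opt2, images, labels, T=1.0):
    # One forward pass per net, shared by both losses
    logits1, logits2 = net1(images), net2(images)
    loss1 = F.cross_entropy(logits1, labels) + kl_to_peer(logits1, logits2, T)
    loss2 = F.cross_entropy(logits2, labels) + kl_to_peer(logits2, logits1, T)
    opt1.zero_grad()
    opt2.zero_grad()
    # Gradients of loss1 reach only net1 and those of loss2 only net2,
    # so a single backward on the sum updates both nets at once
    (loss1 + loss2).backward()
    opt1.step()
    opt2.step()
    return loss1.item(), loss2.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    net1, net2 = nn.Linear(4, 3), nn.Linear(4, 3)
    opt1 = torch.optim.SGD(net1.parameters(), lr=0.1)
    opt2 = torch.optim.SGD(net2.parameters(), lr=0.1)
    x, y = torch.randn(6, 4), torch.randint(0, 3, (6,))
    print(dml_step(net1, net2, opt1, opt2, x, y))
```

Compared with an alternating update, where the second net sees the first net's updated predictions, this variant uses predictions from the same step for both nets, trading strict alternation for one forward pass per net.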

Todo list:
- KDSVD (now has some bugs)
- QuEST: Quantized Embedding Space for Transferring Knowledge
- EEL: Learning an Evolutionary Embedding via Massive Knowledge Distillation
- OnAdvFD: Feature-map-level Online Adversarial Knowledge Distillation
- CS-KD: Regularizing Class-wise Predictions via Self-knowledge Distillation
- PAD: Prime-Aware Adaptive Distillation
- DCM: Knowledge Transfer via Dense Cross-Layer Mutual-Distillation
- ESKD: On the Efficacy of Knowledge Distillation
- GKA: Teachers Do More Than Teach: Compressing Image-to-Image Models
- KD-fn: Feature Normalized Knowledge Distillation for Image Classification
- FSD: Knowledge Distillation via Adaptive Instance Normalization
- CSKD: Improving Knowledge Distillation via Category Structure
- CD & GKD: Channel Distillation: Channel-Wise Attention for Knowledge Distillation
- CRCD: Complementary Relation Contrastive Distillation
- MGD: Matching Guided Distillation
- SSKD: Knowledge Distillation Meets Self-Supervision
- KR: Distilling Knowledge via Knowledge Review
- AFD-SAD: Show, Attend and Distill: Knowledge Distillation via Attention-based Feature Matching
- SRRL: Knowledge Distillation via Softmax Regression Representation Learning
- SemCKD: Cross-Layer Distillation with Semantic Calibration
- SKD: Reducing the Teacher-Student Gap via Spherical Knowledge Disitllation
- IFDM: Heterogeneous Knowledge Distillation using Information Flow Modeling
- LKD: Local Correlation Consistency for Knowledge Distillation
- HKD: Distilling Holistic Knowledge with Graph Neural Networks
- LONDON: Lipschitz Continuity Guided Knowledge Distillation
- CDD: Channel-Wise Knowledge Distillation for Dense Prediction
- SCKD: Student Customized Knowledge Distillation: Bridging the Gap Between Student and Teacher
- ICKD: Exploring Inter-Channel Correlation for Diversity-preserved Knowledge Distillation
- DKMF: Distilling Knowledge by Mimicking Features
- WCoRD: Wasserstein Contrastive Representation Distillation
- KDCL: Online Knowledge Distillation via Collaborative Learning
- ONE: Knowledge Distillation by On-the-Fly Native Ensemble

Requirements:
- python 3.7
- pytorch 1.3.1
- torchvision 0.4.2

This repo is partly based on the following repos. Many thanks to their authors.
- HobbitLong/RepDistiller
- bhheo/BSS_distillation
- clovaai/overhaul-distillation
- passalis/probabilistic_kt
- lenscloth/RKD
If you employ the listed KD methods in your research, please cite the corresponding papers.