by Yuki M. Asano* and Aaqib Saeed* (*Equal Contribution)
Extrapolating from one image. Strongly augmented patches from a single image are used to train a student (S) to distinguish semantic classes, such as those in ImageNet. The student neural network is initialized randomly and learns from a pretrained teacher (T) via KL-divergence. Although almost none of the target categories are present in the image, we find student performance of >66% top-1 accuracy when classifying ImageNet's 1000 classes. In this paper, we develop this single-datum learning framework and investigate it across datasets and domains.
- A minimal framework for training neural networks with a single datum from scratch using distillation.
- Extensive ablations of the proposed method, such as its dependency on the source image and the choice of augmentations and network architectures.
- Large-scale empirical evidence of neural networks' ability to extrapolate on more than 13 image, video, and audio datasets.
- Qualitative insights on what and how neural networks trained with a single image learn.
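For orientation, the distillation step described above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the repo's exact training code; the temperature value and helper names are assumptions.

```python
import torch
import torch.nn.functional as F

def distill_step(student, teacher, patches, optimizer, temperature=8.0):
    """One KD update on a batch of augmented single-image patches."""
    teacher.eval()
    with torch.no_grad():
        t_logits = teacher(patches)          # soft targets from the frozen teacher
    s_logits = student(patches)
    # KL-divergence between temperature-scaled output distributions
    loss = F.kl_div(
        F.log_softmax(s_logits / temperature, dim=1),
        F.softmax(t_logits / temperature, dim=1),
        reduction="batchmean",
    ) * temperature ** 2
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```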
We compare activation-maximization-based visualizations generated with the Lucent library. Even though it has never seen an image of a panda, the model trained with a teacher and only single-image inputs has a good idea of what a panda looks like.
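As a pointer, such visualizations can be produced with Lucent roughly as follows. The layer name ("fc") and class index (388, ImageNet's giant panda) are assumptions for illustration; to reproduce the comparison you would load the distilled student checkpoint into the same architecture.

```python
import torch
from torchvision import models
from lucent.optvis import render

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet50(pretrained=True).to(device).eval()

# Optimize an input image to maximally activate the "giant panda" logit.
render.render_vis(model, "fc:388")
```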
In each of the folders (cifar/, in1k/, video/) you will find a requirements.txt file. Install the packages as follows:
pip3 install -r requirements.txt
To generate the single-image data, please refer to the data_generation folder.
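Conceptually, the generation amounts to extracting many augmented crops from one source image and writing them out as a static training set. The sketch below is illustrative only; crop size, jitter strengths, patch count, and file layout are assumptions, not the script's exact settings.

```python
import os
from PIL import Image
from torchvision import transforms

augment = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.2),
])

source = Image.open("single_image.jpg").convert("RGB")
os.makedirs("single_image_dataset", exist_ok=True)
for i in range(50000):  # the number of patches is a free choice
    augment(source).save(f"single_image_dataset/patch_{i:06d}.jpg")
```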
There is a main distill.py file for each experiment type: small-scale images (2a), large-scale images (2b), and video (2c). Note: 2a uses TensorFlow, while 2b and 2c use PyTorch.
e.g., run it with the "Animal" single-image dataset as follows:
# in cifar folder:
python3 distill.py --dataset=cifar10 --image=/path/to/single_image_dataset/ \
--student=wrn_16_4 --teacher=wrn_40_4
Note that we provide a pretrained teacher model for reproducibility.
# in in1k folder:
python3 distill.py --dataset=in1k --testdir /ILSVRC12/val/ \
--traindir=/path/to/dataset/ --student_arch=resnet50 --teacher_arch=resnet18
Note that teacher models are automatically downloaded from torchvision or timm.
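For reference, the two sources amount to standard one-liners; the snippet below only illustrates where the pretrained teacher weights come from.

```python
import timm
from torchvision import models

teacher_tv = models.resnet18(pretrained=True)                  # torchvision teacher
teacher_timm = timm.create_model("resnet18", pretrained=True)  # timm teacher
teacher_tv.eval()
teacher_timm.eval()
```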
# in video folder:
python3 distill.py --dataset=k400 --traindir=/dataset/with/vids --test_data_path /path/to/k400/val
Note that teacher models are automatically downloaded from PyTorchVideo when you distill a K400 model.
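The K400 teacher corresponds to the Kinetics-400 pretrained x3d_xs model that PyTorchVideo exposes via torch.hub; a hedged sketch of fetching it manually:

```python
import torch

# Kinetics-400 pretrained X3D-XS teacher from the facebookresearch/pytorchvideo hub.
teacher = torch.hub.load("facebookresearch/pytorchvideo", "x3d_xs", pretrained=True)
teacher.eval()
```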
Large-scale (224x224) image ResNet-50 models trained for 200 epochs:
Dataset | Teacher | Student | Top-1 Acc. | Checkpoint |
---|---|---|---|---|
ImageNet-12 | R18 | R50 | 66.2% | R50 weights |
ImageNet-12 | R50 | R50 | 55.5% | R50 weights |
Places365 | R18 | R50 | 50.3% | R50 weights |
Flowers101 | R18 | R50 | 81.5% | R50 weights |
Pets37 | R18 | R50 | 76.8% | R50 weights |
IN100 | R18 | R50 | 66.2% | R50 weights |
STL-10 | R18 | R50 | 93.9% | R50 weights |
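To evaluate one of the released students, the checkpoint can be loaded into a standard torchvision ResNet-50. The filename and the exact state_dict layout below are assumptions; adapt the key handling (and the number of classes) to the downloaded file.

```python
import torch
from torchvision import models

model = models.resnet50(num_classes=1000)
ckpt = torch.load("r50_in1k_student.pth", map_location="cpu")
state_dict = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
model.load_state_dict(state_dict, strict=False)  # strict=False tolerates prefix mismatches
model.eval()
```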
Video x3d_xs_e (expanded) models (160x160 crop, 4 frames) trained for 400 epochs:
Dataset | Teacher | Student | Top-1 Acc. | Checkpoint |
---|---|---|---|---|
K400 | x3d_xs | x3d_xs_e | 51.8% | weights |
UCF101 | x3d_xs | x3d_xs_e | 75.2% | weights |
@inproceedings{asano2023augmented,
title={The Augmented Image Prior: Distilling 1000 Classes by Extrapolating from a Single Image},
author={Asano, Yuki M. and Saeed, Aaqib},
booktitle={International Conference on Learning Representations (ICLR)},
year={2023}
}