This repository open-sources the code for ViTAS: Vision Transformer Architecture Search. ViTAS searches for pure transformer architectures, which do not include CNN convolutions or operations related to inductive bias.
- torch>=1.4.0
- torchvision
- pymoo==0.3.0 for evaluation: pip install pymoo==0.3.0 --user
- Change 'data_dir' in the yaml files under the search/retrain/inference directories to your ImageNet data path. Note that each yaml has four 'data_dir' entries: for training the supernet (train data), evolutionary sampling with the supernet (val data), retraining the searched architecture (train data), and testing the trained architecture (test data).
- This code relies on slurm for distributed training.
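Because each yaml holds four 'data_dir' entries that point to different splits, editing them by hand is error-prone. The stdlib-only Python sketch below rewrites them in file order. It assumes each entry is a single-line `data_dir: <path>` scalar and that the entries appear in the order listed above (supernet train, evolutionary-search val, retrain train, test), so check your files before relying on it. `set_data_dirs` and `update_yaml_file` are hypothetical helpers, not part of the repository.

```python
import re
from pathlib import Path

# Matches a single-line YAML entry such as "    data_dir: /old/path",
# capturing the indentation and key so that only the value is replaced.
DATA_DIR_LINE = re.compile(r"^([ \t]*data_dir[ \t]*:[ \t]*).*$", re.MULTILINE)


def set_data_dirs(yaml_text, new_paths):
    """Replace the k-th 'data_dir' value with new_paths[k], in file order."""
    found = DATA_DIR_LINE.findall(yaml_text)
    if len(found) != len(new_paths):
        raise ValueError(
            f"found {len(found)} data_dir entries, expected {len(new_paths)}"
        )
    remaining = iter(new_paths)
    # A function replacement avoids backslash-escape processing in the paths.
    return DATA_DIR_LINE.sub(lambda m: m.group(1) + next(remaining), yaml_text)


def update_yaml_file(path, new_paths):
    """Rewrite the 'data_dir' entries of one yaml file in place."""
    p = Path(path)
    p.write_text(set_data_dirs(p.read_text(), new_paths))
```

For example, `update_yaml_file("inference/ViTAS_1.3G_inference.yaml", [train_dir, val_dir, train_dir, test_dir])` would point one config at your local ImageNet copy, where the four directories are your own paths.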
Illustration of the private class tokens and private self-attention maps in ViTAS. For the private class tokens, each of the two patch sizes p is assigned its own independent class token, so that under one patch-size setting we obtain the patches for that p together with a private class token. For the private self-attention maps, the value v is shared among the four cases with different head numbers, while q and k are obtained independently for each case.
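The sketch below models only the parameter-sharing layout described in the caption, not the repository's PyTorch modules: one private class token per patch size, independent q/k projections per head-number case, and a single v projection shared by all cases. The concrete patch sizes, head numbers, and toy width are hypothetical placeholders, and the class name `PrivateParams` is illustrative.

```python
import random

# Hypothetical search-space values (illustrative only).
PATCH_SIZES = (16, 32)
HEAD_NUMS = (3, 6, 12, 16)
EMBED_DIM = 48  # toy width, divisible by every head number above


def rand_vector(n, rng):
    return [rng.uniform(-0.02, 0.02) for _ in range(n)]


def rand_matrix(rows, cols, rng):
    return [rand_vector(cols, rng) for _ in range(rows)]


class PrivateParams:
    """Parameter layout of private class tokens and private self-attention maps."""

    def __init__(self, seed=0):
        rng = random.Random(seed)
        # One independent (private) class token per patch size.
        self.cls_tokens = {p: rand_vector(EMBED_DIM, rng) for p in PATCH_SIZES}
        # Independent q and k projections for each head-number case.
        self.q_proj = {h: rand_matrix(EMBED_DIM, EMBED_DIM, rng) for h in HEAD_NUMS}
        self.k_proj = {h: rand_matrix(EMBED_DIM, EMBED_DIM, rng) for h in HEAD_NUMS}
        # A single v projection shared by all head-number cases.
        self.v_proj = rand_matrix(EMBED_DIM, EMBED_DIM, rng)

    def select(self, patch_size, num_heads):
        """Return the parameters used by one sampled (patch_size, num_heads) path."""
        return {
            "cls_token": self.cls_tokens[patch_size],
            "q": self.q_proj[num_heads],
            "k": self.k_proj[num_heads],
            "v": self.v_proj,
        }


params = PrivateParams()
a = params.select(16, 3)
b = params.select(32, 12)
print(a["v"] is b["v"])  # True: v is shared
print(a["q"] is b["q"], a["cls_token"] is b["cls_token"])  # False False
```

Sampling `(16, 3)` and `(32, 12)` returns the same `v` object but different `q`, `k`, and class-token objects, which is the sharing pattern the figure illustrates.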
We incorporate two strategies in ViTAS for searching the optimal width, i.e., BCNet and AutoSlim.
We strongly recommend the BCNet mode, since BCNet trains the supernet more fairly and leads to better search results. In the BCNet experiments, with the same search budget, BCNet surpasses AutoSlim by 0.8% in Top-1 accuracy.
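To see why BCNet is described as training the supernet more fairly, the simplified count below compares how often each channel of a layer is trained when every width is sampled once. It assumes AutoSlim always keeps the leftmost channels and that BCNet evaluates each width with both its leftmost and its rightmost channel subsets; it is a counting illustration, not the repository's training code, and the function names are hypothetical.

```python
def autoslim_channels(width, total):
    """AutoSlim-style slimming: width w keeps the leftmost w channels."""
    return [list(range(width))]


def bcnet_channels(width, total):
    """BCNet-style bilateral coupling: width w uses both the leftmost
    and the rightmost w channels."""
    return [list(range(width)), list(range(total - width, total))]


def usage_counts(select, total):
    """Count how often each channel is trained when widths 1..total are each sampled once."""
    counts = [0] * total
    for width in range(1, total + 1):
        for path in select(width, total):
            for channel in path:
                counts[channel] += 1
    return counts


print(usage_counts(autoslim_channels, 6))  # [6, 5, 4, 3, 2, 1]
print(usage_counts(bcnet_channels, 6))     # [7, 7, 7, 7, 7, 7]
```

With AutoSlim the leftmost channels are trained far more often than the rightmost ones, whereas the bilateral coupling gives every channel the same count (total + 1).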
For example, search a 1G architecture with ViTAS:
chmod +x ./script/command.sh
chmod +x ./script/vit_1G_search.sh
chmod +x ./script/vit_1G_search_AS.sh
./script/vit_1G_search.sh  # BCNet mode
./script/vit_1G_search_AS.sh  # AutoSlim mode
For example, to retrain our 1.3G architecture searched by ViTAS:
chmod +x ./script/command.sh
chmod +x ./script/vit_1.3G_retrain.sh
./script/vit_1.3G_retrain.sh
For example, to run inference with our 1.3G architecture searched by ViTAS:
chmod +x ./script/command.sh
chmod +x ./script/vit_1.3G_inference.sh
./script/vit_1.3G_inference.sh
Despite the inspiring results of ViTAS, the searched ViT architectures are complex and hard for researchers to remember. For practicality, we restrict all transformer blocks in a single architecture (i.e., a cell) to share the same structure, including head number and output dimension, with a fixed patch size of 16. Under this setting, the searched block-level architecture with 1.3G FLOPs achieves 74.7% Top-1 accuracy on ImageNet; its structure is shown in the table below.
Number | Type | Patch size / #Heads | Output Dim |
---|---|---|---|
1 | Embedding | 16 | 230 |
12 | MHSA | 3 | 432 |
-- | MLP | - | 720 |
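The block-level result can be written down as a small configuration object. The sketch below only records the table values and expands the cell into its 12 identical blocks; the field names and the class `BlockLevelViT` are hypothetical, and the 224×224 input resolution used for the token count is an assumption (the standard ImageNet setting), not something stated in the table.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BlockLevelViT:
    """Hypothetical config record for the searched block-level (cell) architecture."""
    patch_size: int = 16   # Embedding row, "Patch size"
    embed_dim: int = 230   # Embedding row, "Output Dim"
    depth: int = 12        # number of identical MHSA + MLP blocks
    num_heads: int = 3     # MHSA row, "#Heads"
    mhsa_dim: int = 432    # MHSA row, "Output Dim"
    mlp_dim: int = 720     # MLP row, "Output Dim"
    image_size: int = 224  # assumption: standard ImageNet input resolution

    def num_tokens(self) -> int:
        """Number of patch tokens plus one class token."""
        if self.image_size % self.patch_size:
            raise ValueError("image_size must be divisible by patch_size")
        side = self.image_size // self.patch_size
        return side * side + 1

    def blocks(self) -> list:
        """Expand the cell: every block has the same MHSA and MLP settings."""
        block = {"heads": self.num_heads, "mhsa_dim": self.mhsa_dim, "mlp_dim": self.mlp_dim}
        return [dict(block) for _ in range(self.depth)]


cfg = BlockLevelViT()
print(cfg.num_tokens())   # 197
print(len(cfg.blocks()))  # 12
```

With patch size 16 on a 224×224 image, each block processes 14 × 14 = 196 patch tokens plus one class token, i.e., 197 tokens.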
In each yaml, 'save_path' under 'search' controls all paths (e.g., line 34 in inference/ViTAS_1.3G_inference.yaml). The code automatically builds 'save_path' + 'search/checkpoint/' for your supernet and 'save_path' + 'retrain/checkpoint' for retraining the searched architecture.
Therefore, to run inference with a provided .pth file, place it at 'save_path/retrain/checkpoint/download.pth', where 'save_path' is specified in the yaml and download.pth is the file downloaded from the table below.
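The path convention above can be reproduced with a few lines of stdlib Python, for example to put a downloaded checkpoint in place before running the inference script. The helper names are hypothetical, and the file keeps its downloaded name because the text leaves open whether 'download.pth' is a literal file name or a placeholder; rename it if your setup expects a specific name.

```python
import os
import shutil


def search_checkpoint_dir(save_path):
    """Directory for supernet checkpoints: save_path + search/checkpoint."""
    return os.path.join(save_path, "search", "checkpoint")


def retrain_checkpoint_dir(save_path):
    """Directory for retrained checkpoints: save_path + retrain/checkpoint."""
    return os.path.join(save_path, "retrain", "checkpoint")


def install_pretrained(downloaded_file, save_path):
    """Copy a downloaded .pth file into save_path/retrain/checkpoint/ and return its new path."""
    target_dir = retrain_checkpoint_dir(save_path)
    os.makedirs(target_dir, exist_ok=True)
    target = os.path.join(target_dir, os.path.basename(downloaded_file))
    shutil.copyfile(downloaded_file, target)
    return target
```

For example, `install_pretrained("ViTAS-C.pth", save_path)`, with `save_path` set to the value in your yaml; the file name `ViTAS-C.pth` is illustrative.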
The extraction code for Baidu Cloud is 'c7gn'.
Model name | FLOPs | Top 1 | Top 5 | Download |
---|---|---|---|---|
ViTAS-A | 858M | 71.1% | 89.8% | Google Drive, Baidu Cloud |
ViTAS-B | 1.0G | 72.4% | 90.6% | Google Drive, Baidu Cloud |
ViTAS-C | 1.3G | 74.7% | 92.0% | Google Drive, Baidu Cloud |
ViTAS-E | 2.7G | 77.4% | 93.8% | Google Drive, Baidu Cloud |
ViTAS-F | 4.9G | 80.6% | 95.1% | Google Drive, Baidu Cloud |
For a fair comparison with the DeiT and ViT architectures, we also provide their results in the table below:
Model name | FLOPs | Top 1 | Top 5 |
---|---|---|---|
DeiT-Ti | 1.3G | 72.2% | 80.1% |
DeiT-S | 4.6G | 79.8% | 85.7% |
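To make the comparison concrete, the snippet below computes the Top-1 gap between ViTAS and DeiT models at similar FLOPs, using only numbers copied from the two tables above. The pairing (ViTAS-C vs. DeiT-Ti, ViTAS-F vs. DeiT-S) is chosen here by nearest FLOPs and is not prescribed by the original comparison; `top1_gap` is a hypothetical helper.

```python
# FLOPs (G) and Top-1 accuracy (%) copied from the tables above.
VITAS = {"ViTAS-C": (1.3, 74.7), "ViTAS-F": (4.9, 80.6)}
DEIT = {"DeiT-Ti": (1.3, 72.2), "DeiT-S": (4.6, 79.8)}


def top1_gap(ours, baseline):
    """Return (Top-1 gain in points, extra FLOPs in G), rounded to one decimal."""
    (flops_a, acc_a), (flops_b, acc_b) = VITAS[ours], DEIT[baseline]
    return round(acc_a - acc_b, 1), round(flops_a - flops_b, 1)


print(top1_gap("ViTAS-C", "DeiT-Ti"))  # (2.5, 0.0)
print(top1_gap("ViTAS-F", "DeiT-S"))   # (0.8, 0.3)
```

At equal FLOPs (1.3G), ViTAS-C is 2.5 points ahead of DeiT-Ti; ViTAS-F is 0.8 points ahead of DeiT-S while using 0.3G more FLOPs.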
If you find ViTAS interesting or helpful for your research, please consider citing it:
@article{vision,
title={Vision Transformer Architecture Search},
author={Su, Xiu and You, Shan and Xie, Jiyang and Zheng, Mingkai and Wang, Fei and Qian, Chen and Zhang, Changshui and Wang, Xiaogang and Xu, Chang},
journal={arXiv preprint arXiv:2106.13700},
year={2021}
}
@inproceedings{bcnet,
title={BCNet: Searching for Network Width with Bilaterally Coupled Network},
author={Su, Xiu and You, Shan and Wang, Fei and Qian, Chen and Zhang, Changshui and Xu, Chang},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={2175--2184},
year={2021}
}