This repo contains the official PyTorch code and pre-trained models for FLatten Transformer (ICCV 2023).
- May 28, 2024: Fixed a numerical instability issue. FLatten Transformer can now be trained with automatic mixed precision (AMP) or float16.
The quadratic computation complexity of self-attention has been a persistent challenge when applying Transformer models to vision tasks. Linear attention, on the other hand, offers a much more efficient alternative with linear complexity by approximating the Softmax operation through carefully designed mapping functions.

In this paper, we first perform a detailed analysis of the inferior performance of linear attention from two perspectives: focus ability and feature diversity. Then, we introduce a simple yet effective mapping function and an efficient rank restoration module, and propose our Focused Linear Attention (FLatten), which adequately addresses these concerns and achieves both high efficiency and expressive capability.
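To make the idea concrete, below is a minimal PyTorch sketch of focused linear attention. It is illustrative rather than this repo's exact module (see the model definitions in the repo for the official code): the class name, `focusing_factor`, and the exact focusing function are assumptions following the paper's description.

```python
import torch
import torch.nn as nn

class FocusedLinearAttention(nn.Module):
    """Illustrative sketch (hypothetical names, not the repo's exact module).
    Q/K pass through a focused mapping and are multiplied in (K^T V) order,
    so complexity is linear in sequence length N; a depthwise convolution
    over V acts as the rank restoration module."""

    def __init__(self, dim, num_heads=8, focusing_factor=3):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.focusing_factor = focusing_factor
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        # rank restoration: depthwise conv applied to V as a 2D feature map
        self.dwc = nn.Conv2d(self.head_dim, self.head_dim, kernel_size=3,
                             padding=1, groups=self.head_dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        def focus(t):
            # focused mapping: raise positive features to a power to sharpen
            # their direction, then rescale to preserve the original norm
            t = torch.relu(t) + 1e-6
            tp = t ** self.focusing_factor
            return tp * (t.norm(dim=-1, keepdim=True) /
                         tp.norm(dim=-1, keepdim=True))

        q, k = focus(q), focus(k)
        # split heads: (B, num_heads, N, head_dim)
        q, k, v = (t.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
                   for t in (q, k, v))
        # linear attention: compute K^T V first -> O(N d^2) instead of O(N^2 d)
        kv = k.transpose(-2, -1) @ v
        denom = q @ k.sum(dim=-2, keepdim=True).transpose(-2, -1) + 1e-6
        out = (q @ kv) / denom
        # rank restoration: add the depthwise-convolved values
        v_map = v.reshape(B * self.num_heads, H, W, self.head_dim).permute(0, 3, 1, 2)
        out = out + self.dwc(v_map).permute(0, 2, 3, 1).reshape(
            B, self.num_heads, N, self.head_dim)
        return self.proj(out.transpose(1, 2).reshape(B, N, C))
```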
- Comparison of different models on ImageNet-1K.
- Accuracy-Runtime curve on ImageNet.
- Python 3.9
- PyTorch == 1.11.0
- torchvision == 0.12.0
- numpy
- timm == 0.4.12
- einops
- yacs
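A matching environment can be set up with pip, for example (a CUDA-specific build of PyTorch 1.11.0 may need to be selected per the instructions on pytorch.org):

```bash
pip install torch==1.11.0 torchvision==0.12.0 timm==0.4.12 numpy einops yacs
```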
The ImageNet dataset should be prepared as follows:
```
$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
```
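This is the standard `torchvision.datasets.ImageFolder` layout, so the splits can be sanity-checked directly (assuming the dataset sits at `./imagenet`, a placeholder path):

```python
from torchvision import datasets

# Both splits should load without errors; ImageNet-1K has 1000 classes,
# ~1.28M training images, and 50,000 validation images.
train_set = datasets.ImageFolder('imagenet/train')
val_set = datasets.ImageFolder('imagenet/val')
print(len(train_set.classes), len(train_set))
print(len(val_set.classes), len(val_set))
```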
We provide several pretrained models based on different architectures, as listed below.
| model | resolution | acc@1 | config | pretrained weights |
| --- | --- | --- | --- | --- |
| FLatten-PVT-T | 224 | 77.8 (+2.7) | config | TsinghuaCloud |
| FLatten-PVTv2-B0 | 224 | 71.1 (+0.6) | config | TsinghuaCloud |
| FLatten-Swin-T | 224 | 82.1 (+0.8) | config | TsinghuaCloud |
| FLatten-Swin-S | 224 | 83.5 (+0.5) | config | TsinghuaCloud |
| FLatten-Swin-B | 224 | 83.8 (+0.3) | config | TsinghuaCloud |
| FLatten-Swin-B | 384 | 85.0 (+0.5) | config | TsinghuaCloud |
| FLatten-CSwin-T | 224 | 83.1 (+0.4) | config | TsinghuaCloud |
To evaluate a pretrained model on ImageNet, run:

```bash
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --eval --resume <path-to-pretrained-weights>
```
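For example, to evaluate FLatten-Swin-T on a single GPU (the paths below are placeholders; set `--nproc_per_node` to your GPU count):

```bash
python -m torch.distributed.launch --nproc_per_node=1 main.py --cfg ./cfgs/flatten_swin_t.yaml --data-path /data/imagenet --output ./output --eval --resume ./flatten_swin_t.pth
```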
The outputs of the four T/B0 pretrained models are:

```
[2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 294): INFO * Acc@1 77.758 Acc@5 93.910
[2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 149): INFO Accuracy of the network on the 50000 test images: 77.8%

[2023-07-21 07:51:36 flatten_pvt_v2_b0] (main.py 294): INFO * Acc@1 71.098 Acc@5 90.596
[2023-07-21 07:51:36 flatten_pvt_v2_b0] (main.py 149): INFO Accuracy of the network on the 50000 test images: 71.1%

[2023-07-21 07:46:13 flatten_swin_tiny_patch4_224] (main.py 294): INFO * Acc@1 82.106 Acc@5 95.900
[2023-07-21 07:46:13 flatten_swin_tiny_patch4_224] (main.py 149): INFO Accuracy of the network on the 50000 test images: 82.1%

[2023-07-21 07:52:46 FLatten_CSWin_tiny] (main.py 294): INFO * Acc@1 83.130 Acc@5 96.376
[2023-07-21 07:52:46 FLatten_CSWin_tiny] (main.py 149): INFO Accuracy of the network on the 50000 test images: 83.1%
```
- To train FLatten-PVT-T/S/M/B on ImageNet from scratch, run:

```bash
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_t.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_s.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_m.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_b.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
```
- To train FLatten-PVTv2-B0/B1/B2/B3/B4 on ImageNet from scratch, run:

```bash
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b0.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b1.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b2.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b3.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b4.yaml --data-path <imagenet-path> --output <output-path>
```
- To train FLatten-Swin-T/S/B on ImageNet from scratch, run:

```bash
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_t.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_s.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_b.yaml --data-path <imagenet-path> --output <output-path>
```
- To train FLatten-CSwin-T/S/B on ImageNet from scratch, run:

```bash
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_t.yaml --data-path <imagenet-path> --output <output-path> --model-ema --model-ema-decay 0.99984
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_s.yaml --data-path <imagenet-path> --output <output-path> --model-ema --model-ema-decay 0.99984
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_b.yaml --data-path <imagenet-path> --output <output-path> --model-ema --model-ema-decay 0.99982
```
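The `--model-ema` flag keeps an exponential moving average (EMA) of the model weights and evaluates that average. Below is a minimal sketch of the update rule, for illustration only (`main_ema.py` contains the actual implementation used for training):

```python
import copy
import torch

class EmaSketch:
    """Illustrative EMA of model weights (not the class used by main_ema.py).
    After every optimizer step: ema = decay * ema + (1 - decay) * current."""

    def __init__(self, model, decay=0.99984):
        self.decay = decay
        self.module = copy.deepcopy(model).eval()  # shadow copy used for eval
        for p in self.module.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        for ema_v, v in zip(self.module.state_dict().values(),
                            model.state_dict().values()):
            if ema_v.dtype.is_floating_point:
                ema_v.mul_(self.decay).add_(v, alpha=1 - self.decay)
            else:
                ema_v.copy_(v)  # integer buffers (e.g. BN step counters)
```

With a decay of 0.99984, the effective averaging horizon is roughly 1 / (1 - 0.99984) ≈ 6250 steps.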
- To fine-tune a FLatten-Swin-B model pre-trained at 224x224 resolution on 384x384 resolution, run:

```bash
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_b_384.yaml --data-path <imagenet-path> --output <output-path> --pretrained <path-to-224x224-pretrained-weights>
```
- To fine-tune a FLatten-CSwin-B model pre-trained at 224x224 resolution on 384x384 resolution, run:

```bash
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_b_384.yaml --data-path <imagenet-path> --output <output-path> --pretrained <path-to-224x224-pretrained-weights> --model-ema --model-ema-decay 0.99982
```
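In both cases the 224x224 checkpoint has to be adapted to the larger input, since window-relative position parameters are resolution-dependent. A common approach for Swin-style models is bicubic interpolation of the relative position bias table; the function below is an illustrative sketch of that step (not this repo's `--pretrained` loader, whose names may differ):

```python
import torch
import torch.nn.functional as F

def resize_rel_pos_bias(table, src_window=7, dst_window=12):
    """Resize a relative position bias table from a src_window x src_window
    attention window (224 px models) to dst_window x dst_window (384 px).
    table has shape ((2 * src_window - 1) ** 2, num_heads)."""
    num_heads = table.shape[1]
    src_side = 2 * src_window - 1
    dst_side = 2 * dst_window - 1
    # (L, heads) -> (1, heads, side, side) so F.interpolate can resize it
    t = table.T.reshape(1, num_heads, src_side, src_side)
    t = F.interpolate(t, size=(dst_side, dst_side),
                      mode='bicubic', align_corners=False)
    return t.reshape(num_heads, dst_side * dst_side).T
```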
We provide code for visualizing the focused linear attention maps. For example, to visualize the attention in FLatten-Swin-T, insert the following code into the forward pass of the attention module (after `q`, `k`, and `self.dwc` are available):
```python
from visualize import AttnVisualizer

visualizer = AttnVisualizer(qk=[q, k], kernel=self.dwc.weight, name='flatten_swin_t')
visualizer.visualize_all_attn(max_num=196, image='./visualize/img_ori_00809.png')
```
Then run:

```bash
python visualize.py
```
Note: don't forget to modify the path of the FLatten-Swin-T pretrained weights in visualize.py.
This code is developed on top of Swin Transformer. The computational resources supporting this work are provided by Hangzhou High-Flyer AI Fundamental Research Co., Ltd.
If you find this repo helpful, please consider citing us.
```bibtex
@InProceedings{han2023flatten,
  title     = {FLatten Transformer: Vision Transformer using Focused Linear Attention},
  author    = {Han, Dongchen and Pan, Xuran and Han, Yizeng and Song, Shiji and Huang, Gao},
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  year      = {2023}
}
```
If you have any questions, please feel free to contact the authors.

- Dongchen Han: hdc23@mails.tsinghua.edu.cn
- Xuran Pan: pxr18@mails.tsinghua.edu.cn