The implementation of our paper: Towards Robust Vision Transformer (CVPR2022)

vtddggg/Robust-Vision-Transformer

RVT: Towards Robust Vision Transformer

News: We have added adversarial training results for RVT here!

This repository contains PyTorch code for Robust Vision Transformers.

Note: Because the model was trained on our private platform, this ported code has not been tested and may contain bugs. If you run into any problems, feel free to open an issue!

For details, see our paper "Towards Robust Vision Transformer".

Usage

First, clone the repository locally:

git clone https://github.com/vtddggg/Robust-Vision-Transformer.git

Install PyTorch 1.7.0+, torchvision 0.8.1+ and pytorch-image-models (timm) 0.3.2:

conda install -c pytorch pytorch torchvision
pip install timm==0.3.2

In addition, einops and kornia are required to use this implementation:

pip install einops
pip install kornia
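A quick way to confirm that an environment matches the requirements above is to compare the installed versions against them. The sketch below is a minimal check and is not part of this repository: the helper names `parse`, `satisfies` and `check_environment` are hypothetical, and it assumes the packages are installed under the distribution names `torch`, `torchvision`, `timm`, `einops` and `kornia`.

```python
import re
from importlib import metadata

# Minimum versions taken from the install step above; timm is pinned exactly.
REQUIRED = {"torch": "1.7.0", "torchvision": "0.8.1", "timm": "0.3.2"}
EXACT = {"timm"}
# Packages that only need to be present (no version stated in this README).
PRESENCE_ONLY = ["einops", "kornia"]


def parse(version):
    """Turn a string such as '1.7.0+cu110' into the tuple (1, 7, 0)."""
    core = version.split("+")[0]  # drop local build suffixes like '+cu110'
    parts = []
    for piece in core.split(".")[:3]:
        match = re.match(r"\d+", piece)  # keep only the leading digits
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def satisfies(name, installed):
    """Check one installed version string against the requirement for `name`."""
    want, have = parse(REQUIRED[name]), parse(installed)
    return have == want if name in EXACT else have >= want


def check_environment():
    """Return a list of human-readable problems (empty if everything matches)."""
    problems = []
    for name in list(REQUIRED) + PRESENCE_ONLY:
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name}: not installed")
            continue
        if name in REQUIRED and not satisfies(name, installed):
            problems.append(f"{name}: found {installed}, need {REQUIRED[name]}")
    return problems


if __name__ == "__main__":
    for line in check_environment() or ["environment looks consistent"]:
        print(line)
```

Running the script prints one line per missing or mismatched package, which makes it easier to spot a timm version other than 0.3.2 before starting a long training job.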

Training

We use 4 nodes with 8 GPUs each to train RVT-Ti, RVT-S and RVT-B:

RVT-Ti:

python -m torch.distributed.launch --nproc_per_node=8 --nnodes=4 main.py --model rvt_tiny --data-path /path/to/imagenet --output_dir output --dist-eval

RVT-S:

python -m torch.distributed.launch --nproc_per_node=8 --nnodes=4 main.py --model rvt_small --data-path /path/to/imagenet --output_dir output --dist-eval

RVT-B:

python -m torch.distributed.launch --nproc_per_node=8 --nnodes=4 main.py --model rvt_base --data-path /path/to/imagenet --output_dir output --batch-size 32 --dist-eval

You can also fine-tune the pretrained model by adding --pretrained.

If you want to train RVT-Ti*, RVT-S* or RVT-B*, simply specify --model as rvt_tiny_plus, rvt_small_plus or rvt_base_plus, then add --use_patch_aug to enable patch-wise augmentation.
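The training commands above show only --nproc_per_node and --nnodes. When launching across several machines, torch.distributed.launch also accepts --node_rank, --master_addr and --master_port, so that each node knows its index and where to rendezvous. The sketch below builds one command per node under those assumptions. The function names, the address `10.0.0.1` and the port value are hypothetical placeholders, and treating --batch-size as a per-process value is an assumption that this README does not confirm.

```python
import shlex


def launch_command(model, node_rank, master_addr, nnodes=4, gpus_per_node=8,
                   master_port=29500, extra=()):
    """Build the argv list for one node of a multi-node training run."""
    return [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={gpus_per_node}",
        f"--nnodes={nnodes}",
        f"--node_rank={node_rank}",      # index of this machine, 0..nnodes-1
        f"--master_addr={master_addr}",  # address of the rank-0 machine
        f"--master_port={master_port}",
        "main.py",
        "--model", model,
        "--data-path", "/path/to/imagenet",
        "--output_dir", "output",
        *extra,
        "--dist-eval",
    ]


def global_batch_size(per_process_batch, nnodes=4, gpus_per_node=8):
    """Effective batch size, ASSUMING --batch-size is applied per process."""
    return per_process_batch * nnodes * gpus_per_node


if __name__ == "__main__":
    # Print the command to run on each of the 4 nodes for RVT-B.
    for rank in range(4):
        cmd = launch_command("rvt_base", rank, "10.0.0.1",
                             extra=("--batch-size", "32"))
        print(shlex.join(cmd))
    print("global batch size:", global_batch_size(32))
```

Under the per-process assumption, the RVT-B setting of --batch-size 32 on 4 nodes with 8 GPUs each corresponds to 32 × 32 = 1024 images per step.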

Testing

News: Robustness evaluation is now supported! Because of environment differences, robustness results may deviate from the numbers reported in the paper by ±0.1 to 0.3%.

RVT-Ti:

python main.py --eval --pretrained --model rvt_tiny --data-path /path/to/imagenet

RVT-Ti*:

python main.py --eval --pretrained --model rvt_tiny_plus --data-path /path/to/imagenet

RVT-S:

python main.py --eval --pretrained --model rvt_small --data-path /path/to/imagenet

RVT-S*:

python main.py --eval --pretrained --model rvt_small_plus --data-path /path/to/imagenet

RVT-B:

python main.py --eval --pretrained --model rvt_base --data-path /path/to/imagenet

RVT-B*:

python main.py --eval --pretrained --model rvt_base_plus --data-path /path/to/imagenet

To enable robustness evaluation, please add one of --inc_path /path/to/imagenet-c, --ina_path /path/to/imagenet-a, --inr_path /path/to/imagenet-r or --insk_path /path/to/imagenet-sketch to test ImageNet-C, ImageNet-A, ImageNet-R or ImageNet-Sketch.
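To evaluate one model on several benchmarks, it can help to generate the commands programmatically. The sketch below is a minimal illustration, not part of this repository: it uses only the flags listed above, reads "add one of" as one benchmark flag per run, and the names `ROBUSTNESS_FLAGS` and `eval_command` are hypothetical.

```python
import shlex

# Benchmark name -> command-line flag, as listed in this README.
ROBUSTNESS_FLAGS = {
    "imagenet-c": "--inc_path",
    "imagenet-a": "--ina_path",
    "imagenet-r": "--inr_path",
    "imagenet-sketch": "--insk_path",
}


def eval_command(model, imagenet_path, benchmark=None, benchmark_path=None):
    """Build the evaluation argv, optionally with one robustness benchmark."""
    cmd = ["python", "main.py", "--eval", "--pretrained",
           "--model", model, "--data-path", imagenet_path]
    if benchmark is not None:
        if benchmark not in ROBUSTNESS_FLAGS:
            raise ValueError(f"unknown benchmark: {benchmark}")
        if not benchmark_path:
            raise ValueError("benchmark_path is required with a benchmark")
        cmd += [ROBUSTNESS_FLAGS[benchmark], benchmark_path]
    return cmd


if __name__ == "__main__":
    # One command per benchmark for RVT-S; the paths are placeholders.
    for name in ROBUSTNESS_FLAGS:
        cmd = eval_command("rvt_small", "/path/to/imagenet",
                           benchmark=name, benchmark_path=f"/path/to/{name}")
        print(shlex.join(cmd))
```

Each printed line can be run as is, or passed to `subprocess.run` as a list to avoid shell quoting issues.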

If you want to test accuracy under adversarial attacks, please add --fgsm_test or --pgd_test.
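For reference, the two attacks named by these flags are usually defined as follows. This is the standard L∞ formulation, given as an assumption about what the flags compute. The perturbation budget ε, the step size α and the number of PGD steps used by this code are not stated here.

```latex
% FGSM: a single signed-gradient step of size \epsilon
x_{\mathrm{adv}} = x + \epsilon \cdot \operatorname{sign}\!\left( \nabla_{x}\, \mathcal{L}\big(f_\theta(x),\, y\big) \right)

% PGD: iterated signed-gradient steps of size \alpha, each followed by
% projection back onto the \epsilon-ball around the clean input x
x^{(0)} = x, \qquad
x^{(t+1)} = \Pi_{\{x' \,:\, \|x' - x\|_\infty \le \epsilon\}}
\left( x^{(t)} + \alpha \cdot \operatorname{sign}\!\left( \nabla_{x}\, \mathcal{L}\big(f_\theta(x^{(t)}),\, y\big) \right) \right)
```

Here f_θ is the model, L is the classification loss, y is the true label, and Π denotes projection onto the allowed perturbation set. Many PGD implementations also start from a random point inside the ε-ball rather than from x itself.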

Pretrained weights

| Model name | FLOPs | Accuracy (%) | Weights |
|---|---|---|---|
| rvt_tiny | 1.3 G | 78.4 | link |
| rvt_small | 4.7 G | 81.7 | link |
| rvt_base (ImageNet-22k) | 17.7 G | 83.4 | link |
| rvt_tiny* | 1.3 G | 79.3 | link |
| rvt_small* | 4.7 G | 81.8 | link |
| rvt_base* (ImageNet-22k) | 17.7 G | 83.6 | link |

Adversarially trained weights

| Model name | Clean accuracy (%) | PGD accuracy (%) | Weights |
|---|---|---|---|
| adv_deit_tiny | 52.05 | 26.55 | link |
| adv_rvt_tiny | 54.91 | 28.1 | link |

To test these models, run the following commands:

python adv_test.py --model deit_tiny_patch16_224 --ckpt_path adv_deit_tiny.pth
python adv_test.py --model rvt_tiny --ckpt_path adv_rvt_tiny.pth
