Masked Spiking Transformer, ICCV'23: [Paper]. [Poster]. [Video]
The combination of Spiking Neural Networks (SNNs) and Transformers has attracted significant attention due to its potential for high energy efficiency and high performance. However, achieving performance comparable to artificial neural networks (ANNs) in real-world applications remains a considerable challenge. Moreover, mainstream ANN-to-SNN conversion methods trade increased simulation time and power for high SNN performance, and generally consume substantial computational resources during inference.
To address this issue, we propose an energy-efficient architecture, the Masked Spiking Transformer (MST), which combines the benefits of SNNs with the high-performance self-attention mechanism of Transformers through ANN-to-SNN conversion. Furthermore, our Random Spike Masking (RSM) method prunes input spikes during both training and inference to reduce SNN computational costs.
The Masked Spiking Transformer combines the self-attention mechanism and the ANN-to-SNN conversion method, achieving state-of-the-art accuracy on both static and neuromorphic datasets. Experimental results demonstrate that the RSM method reduces redundant spike operations while maintaining model performance over a certain range of mask rates across various model architectures. For instance, the RSM method reduces MST model power by 26.8% at a 75% mask rate with no performance drop.
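To make the idea of pruning input spikes concrete, here is a minimal sketch of random masking applied to a binary spike train. It assumes each spike is dropped independently with probability equal to the mask rate; the function name `random_spike_mask` and its parameters are hypothetical and are not taken from this repository, and the paper's actual masking scheme may differ (for example, in where the mask is applied or how it is sampled).

```python
import random


def random_spike_mask(spikes, mask_rate, rng=None):
    """Drop each input spike independently with probability `mask_rate`.

    Illustrative assumption only: `spikes` is a flat list of 0/1 values and
    the mask is sampled per spike. This is not the repository's code.
    """
    if not 0.0 <= mask_rate <= 1.0:
        raise ValueError("mask_rate must be in [0, 1]")
    if rng is None:
        rng = random.Random()
    # A spike survives only if it fired (1) and the random draw falls
    # outside the masked fraction; silent positions (0) stay silent.
    return [1 if (s and rng.random() >= mask_rate) else 0 for s in spikes]


# Example: mask a short spike train at a 75% mask rate.
rng = random.Random(0)
spikes = [1, 0, 1, 1, 0, 1, 1, 1]
masked = random_spike_mask(spikes, mask_rate=0.75, rng=rng)
print(sum(spikes), "spikes before,", sum(masked), "after masking")
```

Fewer surviving spikes means fewer downstream spike operations, which is the kind of saving the description above refers to; the sketch says nothing about how accuracy is preserved.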
Checkpoints:
MST (0% masking rate): CIFAR-10, CIFAR-100, ImageNet
MST (75% masking rate): CIFAR-10, CIFAR-100, ImageNet
For more training details, please check out our paper and supplementary material. (Note: we used 8×3090 GPU cards for training)
torchrun --nproc_per_node 8 main.py --cfg configs/mst/MST.yaml --batch-size 128 --masking_ratio masking_rate
torchrun --nproc_per_node 8 main.py --cfg configs/mst/MST.yaml --batch-size 128 --snnvalidate True --sim_len 128 --pretrained /path/to/weight/ --dataset imagenet --masking_ratio masking_rate
--sim_len: number of SNN simulation timesteps.
--snnvalidate: enables SNN validation.
--dataset: name of the dataset, choices=['imagenet', 'Cifar100', 'Cifar10'].
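As a rough intuition for why the simulation length matters in ANN-to-SNN conversion, the sketch below simulates a single textbook integrate-and-fire neuron with a constant input and soft reset. Its firing rate over `sim_len` timesteps approximates a ReLU-like activation, and the approximation error shrinks as `sim_len` grows. This is a generic illustration under stated assumptions, not the neuron model used in this repository; `if_neuron_rate` and `threshold` are hypothetical names.

```python
def if_neuron_rate(x, sim_len, threshold=1.0):
    """Rate-coded output of one integrate-and-fire neuron (illustrative only).

    The neuron receives the constant input `x` at every timestep, fires when
    its membrane potential reaches `threshold`, and resets by subtraction.
    """
    v = 0.0      # membrane potential
    spikes = 0   # number of emitted spikes
    for _ in range(sim_len):
        v += x
        if v >= threshold:
            v -= threshold  # soft reset keeps the residual charge
            spikes += 1
    # The scaled firing rate approximates the analog activation.
    return spikes * threshold / sim_len


# Compare a short and a long simulation for the same input.
target = 0.3
for t in (4, 128):
    rate = if_neuron_rate(target, sim_len=t)
    print(f"sim_len={t}: rate={rate}, error={abs(rate - target)}")
```

For inputs between 0 and `threshold`, the error of this rate code is bounded by `threshold / sim_len`, which is the usual reason longer simulations cost more time and energy but track the ANN more closely.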