This repository is the official PyTorch implementation of the paper: Training High-Performance Low-Latency Spiking Neural Networks by Differentiation on Spike Representation (CVPR 2022).
- Python >= 3.6
- PyTorch >= 1.7.1
- torchvision
- Python packages: `pip install numpy matplotlib progress`
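Optionally, the following snippet gives a quick sanity check that the environment meets these requirements (the printed versions and GPU availability are machine-specific):

```python
# Optional environment check: confirm the versions required above.
import torch
import torchvision

print("PyTorch:", torch.__version__)          # expect >= 1.7.1
print("torchvision:", torchvision.__version__)
print("CUDA available:", torch.cuda.is_available())  # training assumes a GPU
```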
To train on CIFAR-10, CIFAR-100, or DVS-CIFAR10, run one of the following commands, replacing [checkpoint_name] with a name for the saved checkpoint. The hyperparameters in the commands are the same as in the paper.
```bash
# CIFAR-10
python -u cifar/main_single_gpu.py --path ./data --dataset cifar10 --model preresnet.resnet18_lif --name [checkpoint_name]
python -u cifar/main_single_gpu.py --path ./data --dataset cifar10 --model preresnet.resnet18_if --Vth 6 --alpha 0.5 --Vth_bound 0.01 --name [checkpoint_name]

# CIFAR-100
python -u cifar/main_single_gpu.py --path ./data --dataset cifar100 --model preresnet.resnet18_lif --name [checkpoint_name]
python -u cifar/main_single_gpu.py --path ./data --dataset cifar100 --model preresnet.resnet18_if --Vth 6 --alpha 0.5 --Vth_bound 0.01 --name [checkpoint_name]

# DVS-CIFAR10
python -u cifar/main_single_gpu.py --path ./data/CIFAR10DVS --dataset CIFAR10DVS --model vgg.vgg11_lif --lr 0.05 --epochs 300 --name [checkpoint_name]
python -u cifar/main_single_gpu.py --path ./data/CIFAR10DVS --dataset CIFAR10DVS --model vgg.vgg11_if --Vth 6 --alpha 0.5 --Vth_bound 0.01 --lr 0.05 --epochs 300 --name [checkpoint_name]
```
By tuning the hyperparameters, DSR achieves good results at low latency (e.g., T=5). The commands below train with T=5 on CIFAR-10; the resulting accuracy is about 94.45% (Fig. 3 in the paper). For other datasets, reduce the learning rate and tune Vth for better performance.
```bash
python -u cifar/main_single_gpu.py --path ./data --dataset cifar10 --model preresnet.resnet18_lif --timesteps 5 --lr 0.05 --Vth 0.6 --alpha 0.5 --Vth_bound 0.001 --delta_t 0.1
python -u cifar/main_single_gpu.py --path ./data --dataset cifar10 --model preresnet.resnet18_if --timesteps 5 --lr 0.05 --Vth 6 --alpha 0.5 --Vth_bound 0.01
```
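For intuition about these flags: `--timesteps` sets the number of simulation steps T, and `--Vth` sets the neurons' firing threshold; the remaining flags correspond to hyperparameters described in the paper. Below is a minimal, illustrative sketch of integrate-and-fire dynamics with a soft (subtractive) reset, which is an assumption made for this example; it is not the repository's DSR neuron implementation (see the model files under `cifar/` for that).

```python
import torch

def if_forward(x_seq, Vth=1.0):
    """Run an IF neuron over a [T, ...] input sequence (sketch only)."""
    v = torch.zeros_like(x_seq[0])   # membrane potential, one per neuron
    spikes = []
    for x in x_seq:                  # one iteration per timestep (T total)
        v = v + x                    # integrate the input current
        s = (v >= Vth).float()       # fire where the threshold is reached
        v = v - s * Vth              # soft (subtractive) reset, assumed here
        spikes.append(s)
    return torch.stack(spikes)       # [T, ...] binary spike trains

# --timesteps controls T; --Vth controls the firing threshold.
out = if_forward(torch.randn(5, 4), Vth=0.6)  # T=5, as in the commands above
```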
The CIFAR-10, CIFAR-100, and DVS-CIFAR10 tasks can also be trained on multiple GPUs, as in the example below.
```bash
python -u -m torch.distributed.launch --nproc_per_node [number_of_gpus] cifar/main_multiple_gpus.py --path ./data --dataset cifar10 --model preresnet.resnet18_lif --name [checkpoint_name]
```
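Note that `torch.distributed.launch` is deprecated in recent PyTorch releases. On a newer version, `torchrun` may work as a replacement, but only if the script reads the local rank from the `LOCAL_RANK` environment variable rather than a `--local_rank` argument; this is untested with this repository.

```bash
torchrun --nproc_per_node [number_of_gpus] cifar/main_multiple_gpus.py --path ./data --dataset cifar10 --model preresnet.resnet18_lif --name [checkpoint_name]
```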
For the ImageNet classification task, we use hybrid training.

First, we train an ANN:

```bash
python imagenet/main.py --arch preresnet_ann.resnet18 --data ./data/imagenet --name model_ann --optimizer SGD --wd 1e-4 --batch-size 256 --lr 0.1
```
Then, we compute the maximum post-activation of each layer as the initialization for the spike thresholds:

```bash
python imagenet/main.py --arch preresnet_cal_Vth.resnet18 --data ./data/imagenet --pre_train model_ann.pth --calculate_Vth resnet18_Vth
```
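Conceptually, this step runs the trained ANN over data and records the maximum value each activation produces; those maxima then initialize the per-layer thresholds. Below is a minimal sketch of the idea using forward hooks; it is illustrative only and not the repository's implementation (the `loader` argument and the choice of hooking `ReLU` modules are assumptions).

```python
import torch

@torch.no_grad()
def max_post_activation(model, loader, device="cuda"):
    """Record the maximum output of each ReLU over a dataset (sketch only)."""
    stats, hooks = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            m = output.max().item()
            stats[name] = max(stats.get(name, float("-inf")), m)
        return hook

    # Hooking ReLU modules is an assumption; the repository decides
    # internally which activations correspond to spiking layers.
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.ReLU):
            hooks.append(module.register_forward_hook(make_hook(name)))

    model.to(device).eval()
    for images, _ in loader:
        model(images.to(device))
    for h in hooks:
        h.remove()
    return stats  # e.g., torch.save(stats, "resnet18_Vth.dict")
```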
Next, we train the SNN:

```bash
python imagenet/main.py --dist-url tcp://127.0.0.1:20500 --dist-backend nccl --multiprocessing-distributed --world-size 1 --rank 0 --arch preresnet_snn.resnet18_if --data ./data/imagenet --pre_train model_ann.pth --load_Vth resnet18_Vth.dict
```
The pretrained ANN model and the calculated thresholds can be downloaded from here and here. Please put them under `./checkpoint/imagenet`.
The data-preprocessing code for DVS-CIFAR10 is based on the spikingjelly repo, and some utility code is from the pytorch-classification repo.