
[NeurIPS 2022] Online Training Through Time for Spiking Neural Networks


OTTT-SNN

This is the PyTorch implementation of the paper Online Training Through Time for Spiking Neural Networks (NeurIPS 2022). [arXiv]

Dependencies and Installation

Training

For OTTT$_A$, run as follows:

python train_cifar.py -data_dir path_to_data_dir -dataset cifar10 -out_dir log_checkpoint_name -gpu-id 0

# For VGG-F model
python train_cifar.py -data_dir path_to_data_dir -dataset cifar100 -out_dir log_checkpoint_name -gpu-id 0 -model online_spiking_vgg11f_ws

python train_cifar10dvs.py -data_dir path_to_data_dir -out_dir log_checkpoint_name -gpu-id 0

python train_imagenet.py -data_dir path_to_data_dir -out_dir log_checkpoint_name -gpu-id 0

For OTTT$_O$, add the -online_update argument, for example:

python train_cifar.py -data_dir path_to_data_dir -dataset cifar10 -out_dir log_checkpoint_name -gpu-id 0 -online_update

The default hyperparameters in the code are the same as in the paper.

Note: The current code only supports single-GPU training.
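The difference between OTTT$_A$ and OTTT$_O$ can be illustrated with a toy, pure-Python sketch. It is not the repository's code. It assumes, following our reading of the paper, three things. First, OTTT tracks a presynaptic eligibility trace a[t] = lam * a[t-1] + s[t] instead of storing the computational graph for backpropagation through time. Second, OTTT$_A$ accumulates the per-step gradients and updates the weights once after the last time step. Third, OTTT$_O$ (the -online_update setting) updates the weights at every time step. The single leaky neuron without firing or reset, the squared-error loss, and the names `ottt_toy`, `lam`, `lr` and `target` are hypothetical simplifications.

```python
# Toy sketch (NOT the repository's implementation) of the two OTTT update
# schedules on a single leaky neuron with no firing or reset.
from typing import List, Optional


def ottt_toy(
    spikes: List[List[float]],    # presynaptic spikes s[t], one list per time step
    target: float,                # hypothetical regression target for u[t]
    lam: float = 0.5,             # leak factor (assumed value)
    lr: float = 0.1,              # learning rate (assumed value)
    online_update: bool = False,  # False ~ OTTT_A, True ~ OTTT_O
    w: Optional[List[float]] = None,
) -> List[float]:
    n = len(spikes[0])
    w = list(w) if w is not None else [0.0] * n
    trace = [0.0] * n     # eligibility trace a[t]
    grad_acc = [0.0] * n  # accumulated gradient (OTTT_A-style only)
    u = 0.0               # membrane potential
    for s_t in spikes:
        # Presynaptic trace: a[t] = lam * a[t-1] + s[t]
        trace = [lam * a + s for a, s in zip(trace, s_t)]
        # Leaky integration: u[t] = lam * u[t-1] + w . s[t] (no firing/reset in this toy)
        u = lam * u + sum(wi * si for wi, si in zip(w, s_t))
        # Instantaneous loss L[t] = 0.5 * (u[t] - target)^2, so dL[t]/du[t] = u[t] - target
        dL_du = u - target
        # Gradient at step t uses only the current trace; no past states are stored
        grad_t = [dL_du * a for a in trace]
        if online_update:
            # OTTT_O-style: apply the update immediately at every time step
            w = [wi - lr * g for wi, g in zip(w, grad_t)]
        else:
            # OTTT_A-style: accumulate now, update once after the last step
            grad_acc = [ga + g for ga, g in zip(grad_acc, grad_t)]
    if not online_update:
        w = [wi - lr * g for wi, g in zip(w, grad_acc)]
    return w


spikes = [[1, 0], [0, 1], [1, 1]]
print(ottt_toy(spikes, target=1.0))                      # accumulated (OTTT_A-style)
print(ottt_toy(spikes, target=1.0, online_update=True))  # online (OTTT_O-style)
```

In this linear, reset-free toy, the trace-based accumulated gradient equals the exact gradient of the summed loss. In real spiking networks, firing and reset make it an approximation. Memory stays constant in the number of time steps either way, because only the trace and the current state are kept.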

Testing

We provide example code for computing firing-rate statistics during evaluation. Run as follows:

python get_rate_cifar.py -data_dir path_to_data_dir -dataset cifar10 -gpu-id 0 -resume path_to_checkpoint

python get_rate_imagenet.py -data_dir path_to_data_dir -gpu-id 0 -resume path_to_checkpoint
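A hypothetical sketch can show what a per-layer firing-rate statistic measures. It assumes the rate is the fraction of (neuron, time step) pairs in which a layer emits a spike. The function name, layer names and data format are made up. The repository's get_rate_cifar.py and get_rate_imagenet.py may define or aggregate the statistic differently, for example by averaging over the whole test set.

```python
# Hypothetical per-layer firing-rate statistic; not taken from the repository.
from typing import Dict, List

SpikeRecord = List[List[int]]  # [time step][neuron] -> 0 or 1


def firing_rates(records: Dict[str, SpikeRecord]) -> Dict[str, float]:
    """Fraction of (neuron, time step) pairs that emitted a spike, per layer."""
    rates = {}
    for layer, steps in records.items():
        num_spikes = sum(sum(step) for step in steps)
        num_slots = sum(len(step) for step in steps)
        rates[layer] = num_spikes / num_slots if num_slots else 0.0
    return rates


# Made-up spike trains for two hypothetical layers over T = 2 time steps
records = {
    "conv1": [[1, 0, 0, 1], [0, 0, 1, 1]],
    "fc": [[0, 0], [0, 1]],
}
print(firing_rates(records))  # {'conv1': 0.5, 'fc': 0.25}
```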

Some pretrained models can be downloaded from Google Drive or Baidu Drive (extraction code: gppq).

Acknowledgement

Some of the code for the neuron model and data preprocessing is adapted from the spikingjelly repository, and some of the utility code is from the pytorch-classification repository.

Contact

If you have any questions, please contact mingqing_xiao@pku.edu.cn.