
The source code of the paper "Human-Level Control through Directly-Trained Deep Spiking Q-Networks".


Besides the implementation of our Deep Spiking Q-Network (DSQN), we also reproduced the vanilla Deep Q-Network (DQN) proposed by Mnih et al. (2015) and the conversion-based Spiking Neural Network (SNN) proposed by Tan, Patel, and Kozma (2021).


For DSQN and the vanilla DQN:

  • Python 3.8.11
  • PyTorch 1.8.2 LTS
  • SpikingJelly
  • CuPy
  • Gym[atari,accept-rom-license]
  • PyYAML
  • TensorBoard

For conversion-based SNN:

  • Python 3.8.11
  • PyTorch 1.8.2 LTS
  • TensorFlow-GPU 1.14
  • OpenAI Baselines
  • Gym[atari,accept-rom-license]
  • PyYAML
  • Matplotlib
  • Pandas
  • tensorboardX
  • scikit-learn
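
Before training or converting, it can help to confirm that the packages listed above are installed. The following is a minimal sketch, not part of this repository: the import names (for example `yaml` for PyYAML and `baselines` for OpenAI Baselines) are assumptions based on how these packages are usually imported, and the helper `missing_modules` is hypothetical.

```python
from importlib.util import find_spec

# Assumed import names for the packages listed above (not taken from this repository).
DSQN_MODULES = ["torch", "spikingjelly", "cupy", "gym", "yaml", "tensorboard"]
CONVERSION_MODULES = [
    "torch", "tensorflow", "baselines", "gym", "yaml",
    "matplotlib", "pandas", "tensorboardX", "sklearn",
]


def missing_modules(modules):
    """Return the names in `modules` that cannot be located in this environment."""
    return [name for name in modules if find_spec(name) is None]


if __name__ == "__main__":
    for label, modules in [("DSQN / DQN", DSQN_MODULES),
                           ("conversion-based SNN", CONVERSION_MODULES)]:
        missing = missing_modules(modules)
        status = "all found" if not missing else "missing: " + ", ".join(missing)
        print(f"{label}: {status}")
```

This only checks that each module can be located; it does not verify the versions listed above.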


This project can be divided into three parts:

  1. The code for DSQN and the vanilla DQN is under the root directory of this project.

    │  .gitignore
    │  LICENSE
    │  log_config.yaml
    │  README.md
    │  run.py	# Program entry.
    ├─agent	# The module for RL agents.
    │      agent.py
    │      __init__.py
    ├─environment	# The module for Gym Atari environments.
    │     atari_wrappers.py
    │     wrappers.py
    │     __init__.py
    ├─neural_network	# The module for the vanilla DQN and our DSQN.
    │      neural_network.py
    │      __init__.py
    └─replay_memory	# The module for experience replay memory.
  2. The code for the conversion-based SNN is under the conversion_based_snn directory.

    │  convert.py	# Program entry.
    │  LICENSE
    │  README.md
    └─bindsnet	# The module for implementing the conversion method.
        │  utils.py
        │  __init__.py
        ├─analysis
        │      pipeline_analysis.py
        │      plotting.py
        │      visualization.py
        │      __init__.py
        ├─conversion
        │      conversion.py
        │      __init__.py
        ├─datasets
        │      alov300.py
        │      collate.py
        │      dataloader.py
        │      davis.py
        │      preprocess.py
        │      README.md
        │      spoken_mnist.py
        │      torchvision_wrapper.py
        │      __init__.py
        ├─encoding
        │      encoders.py
        │      encodings.py
        │      loaders.py
        │      __init__.py
        ├─environment
        │      environment.py
        │      __init__.py
        ├─evaluation
        │      evaluation.py
        │      __init__.py
        ├─learning
        │      learning.py
        │      reward.py
        │      __init__.py
        ├─models
        │      models.py
        │      __init__.py
        ├─network
        │      monitors.py
        │      network.py
        │      nodes.py
        │      topology.py
        │      __init__.py
        ├─pipeline
        │      action.py
        │      base_pipeline.py
        │      dataloader_pipeline.py
        │      environment_pipeline.py
        │      __init__.py
  3. The code for plotting images is under the utils directory.

        plot.py	# Program entry.


For example, you can use the following command to train our DSQN on the Atari game Breakout (a single GPU is used by default):

python run.py --model=Dsqn --env_id=BreakoutNoFrameskip-v4

To train the vanilla DQN on Breakout, use:

python run.py --model=Dqn --env_id=BreakoutNoFrameskip-v4
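
To train on several games in sequence, the two commands above can be generated from a small script. This is only a sketch under stated assumptions: it uses just the `--model` and `--env_id` arguments shown above, the helper `build_train_command` is hypothetical, and `PongNoFrameskip-v4` is an illustrative second Gym environment ID that does not appear elsewhere in this README.

```python
import shlex

# Model names accepted by run.py according to the commands above.
KNOWN_MODELS = ("Dsqn", "Dqn")


def build_train_command(model, env_id):
    """Build the argument list for one training run of run.py."""
    if model not in KNOWN_MODELS:
        raise ValueError(f"unknown model {model!r}, expected one of {KNOWN_MODELS}")
    return ["python", "run.py", f"--model={model}", f"--env_id={env_id}"]


if __name__ == "__main__":
    # Illustrative list of games; only Breakout is mentioned in this README.
    for env_id in ["BreakoutNoFrameskip-v4", "PongNoFrameskip-v4"]:
        for model in KNOWN_MODELS:
            # Print the command; pass the list to subprocess.run(...) to execute it.
            print(shlex.join(build_train_command(model, env_id)))
```

The script only prints the commands, so nothing is launched until the list is handed to `subprocess.run` from the project root.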

To convert a vanilla DQN that has been pre-trained on Breakout (saved under the model directory by default) to an SNN, first change into the conversion_based_snn directory, then execute the following command:

python convert.py --game=BreakoutNoFrameskip-v4 --model=../model/Dqn_BreakoutNoFrameskip-v4_[time].pkl
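
The `[time]` part of the checkpoint name depends on when the DQN was trained. The sketch below picks the most recently modified checkpoint that matches the naming pattern above and builds the corresponding `convert.py` command. It is an illustration rather than code from this repository: the helpers `latest_checkpoint` and `build_convert_command` are hypothetical, and selecting by file modification time is an assumption, since the format of `[time]` is not specified here.

```python
from pathlib import Path


def latest_checkpoint(model_dir, env_id, model="Dqn"):
    """Return the most recently modified '<model>_<env_id>_*.pkl' file in model_dir."""
    candidates = sorted(
        Path(model_dir).glob(f"{model}_{env_id}_*.pkl"),
        key=lambda path: path.stat().st_mtime,  # assumption: newest file is the one wanted
    )
    if not candidates:
        raise FileNotFoundError(f"no {model} checkpoint for {env_id} in {model_dir}")
    return candidates[-1]


def build_convert_command(checkpoint, game):
    """Build the argument list for convert.py using the --game and --model arguments."""
    return [
        "python", "convert.py",
        f"--game={game}",
        f"--model={Path(checkpoint).as_posix()}",
    ]
```

From inside the conversion_based_snn directory, `build_convert_command(latest_checkpoint("../model", "BreakoutNoFrameskip-v4"), "BreakoutNoFrameskip-v4")` yields a command of the same form as the one above.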

More detailed arguments can be found in run.py and conversion_based_snn/convert.py. The default values of all arguments are the same as those reported in our paper.


The code under the conversion_based_snn directory was forked from the publicly available open-source code of Tan, Patel, and Kozma (2021), and our reproduction is based on it. Thanks a lot for their work!