snnax

Spiking neural networks in JAX with equinox.

SNNAX is a lightweight library for implementing Spiking Neural Networks (SNNs) in JAX. It builds on the intuitive Equinox library. The full documentation of SNNAX can be found at https://pgi15.github.io/snnax/.

Installation

You can install SNNAX from PyPI using pip:

pip install snnax

Or you can install the latest version from GitHub using pip:

pip install git+https://github.com/PGI15/snnax

Requires Python 3.9+, JAX 0.4.13+ and Equinox 0.11.1+.
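
To check an existing environment against these minimum versions, a small script along the following lines can help. This is only a sketch: `parse_version` and `check_requirements` are hypothetical helper names, not part of SNNAX, and the version parsing is deliberately simplistic (it keeps the leading integers of each dotted component).

```python
import sys
from importlib import metadata

# Minimum versions taken from the requirements line above.
MINIMUMS = {"jax": (0, 4, 13), "equinox": (0, 11, 1)}


def parse_version(text):
    """Turn '0.4.13' (or '0.4.13.dev1') into a tuple of leading integers."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def check_requirements():
    """Return a list of human-readable problems; an empty list means all good."""
    problems = []
    if sys.version_info < (3, 9):
        problems.append("Python 3.9+ is required")
    for name, minimum in MINIMUMS.items():
        try:
            installed = parse_version(metadata.version(name))
        except metadata.PackageNotFoundError:
            problems.append(f"{name} is not installed")
            continue
        if installed < minimum:
            problems.append(f"{name} {installed} is older than {minimum}")
    return problems
```

Calling `check_requirements()` after installation returns an empty list when all three requirements are met.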

Introduction

SNNAX is a lightweight library that builds on Equinox and JAX to provide a spiking neural network (SNN) simulator for deep learning. It is designed to be easy to use and flexible, allowing users to define their own SNN layers while the common deep learning layers are provided by Equinox. It is fully compatible with JAX and can therefore leverage JAX's function transformations, such as vectorization with jax.vmap, automatic differentiation, and JIT compilation with XLA.
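
To illustrate how these three transformations compose, here is a minimal sketch in plain JAX. It does not use SNNAX: `leaky_step` and `final_potential` are hypothetical names for a toy leaky integrator without spiking, chosen only to keep the example self-contained.

```python
import jax
import jax.numpy as jnp


def leaky_step(v, x, decay):
    """One leaky-integrator update: decay the potential, then add the input."""
    return decay * v + x


def final_potential(decay, inputs):
    """Run the integrator over a 1-D input sequence, starting from zero."""
    def body(v, x):
        v = leaky_step(v, x, decay)
        return v, v

    v_last, _ = jax.lax.scan(body, jnp.zeros(()), inputs)
    return v_last


# Vectorize over a batch of input sequences while sharing the scalar decay.
batched = jax.vmap(final_potential, in_axes=(None, 0))

# Differentiate the mean final potential w.r.t. decay, then JIT-compile.
grad_fn = jax.jit(jax.grad(lambda d, xs: batched(d, xs).mean()))

inputs = jnp.ones((4, 3))           # batch of 4 sequences, 3 time steps each
potentials = batched(0.5, inputs)   # one final potential per sequence
grad = grad_fn(0.5, inputs)         # scalar derivative w.r.t. decay
```

For a constant input of 1 over three steps, the final potential is decay² + decay + 1, so the derivative with respect to decay is 2·decay + 1.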

The following snippet demonstrates how to define a simple SNN in SNNAX. We use the snnax.snn.Sequential class to stack SNN layers and Equinox layers into a feed-forward architecture.

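from functools import partial

import jax
import jax.numpy as jnp

import equinox as eqx
import equinox.nn as nn
import snnax.snn as snn

import optax

# One PRNG key per layer, plus one kept for state initialization
key = jax.random.PRNGKey(0)
key, key1, key2, key3, key4 = jax.random.split(key, 5)

model = snn.Sequential(
    nn.Conv2d(2, 32, 7, 2, key=key1),
    snn.LIF([.9, .8], shape=(8, 8), key=key2),
    snn.flatten(),
    nn.Linear(64, 11, key=key3),
    snn.LIF([.9, .8], shape=11, key=key4)
)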

Next, we define a loss function for a single sample and then use the vectorization features of JAX to create a batched loss function. Note that the output of our model is a tuple of membrane potentials and spikes. The spike output is a list of spike trains, one for each layer of the SNN. For our example, we can simply sum the spikes of the last layer along the time axis to get the spike count.

# Simple batched loss function: the model and the initial states are
# shared across the batch, spikes and targets are mapped over axis 0
@partial(jax.vmap, in_axes=(None, None, 0, 0))
def loss_fn(model, in_states, in_spikes, tgt_class):
    out_state, out_spikes = model(in_states, in_spikes)

    # Spikes from the last layer are summed across time
    pred = out_spikes[-1].sum(-1)
    loss = optax.softmax_cross_entropy(pred, tgt_class)
    return loss

# Calculating the gradient w.r.t. the model (the first argument)
# with Equinox PyTree filters
@eqx.filter_value_and_grad
def loss_and_grad(model, in_states, in_spikes, tgt_class):
    return jnp.mean(loss_fn(model, in_states, in_spikes, tgt_class))

# Finally, we update the parameters using a simple optimizer and
# jit the resulting function (optim is an optax optimizer defined elsewhere)
@eqx.filter_jit
def update(model, opt_state, in_states, in_spikes, tgt_class):
    # Get gradients
    loss, grads = loss_and_grad(model, in_states, in_spikes, tgt_class)

    # Calculate parameter updates using the optimizer
    updates, opt_state = optim.update(grads, opt_state)

    # Update parameter PyTree with Equinox and optax
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss
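
To make the spike-count readout concrete, the following sketch computes the same kind of loss on hand-written spike trains using plain JAX only. It is an illustration under assumptions, not SNNAX code: `spike_count_loss` is a hypothetical name, the cross-entropy is written out with jax.nn.log_softmax instead of optax, and the spike array is assumed to have time as its last axis, matching the `sum(-1)` in the snippet above.

```python
import jax
import jax.numpy as jnp


def spike_count_loss(spikes, target_onehot):
    """Cross-entropy on spike counts for a single sample.

    spikes:        (num_classes, num_steps) array of 0/1 values (assumed layout)
    target_onehot: (num_classes,) one-hot label
    """
    counts = spikes.sum(-1)                  # spike count per output neuron
    log_probs = jax.nn.log_softmax(counts)   # treat the counts as logits
    return -(target_onehot * log_probs).sum()


# Batch the per-sample loss over the leading axis of both arguments.
batched_loss = jax.vmap(spike_count_loss)

spikes = jnp.array([
    [[1., 1., 1., 1.], [0., 0., 0., 0.]],    # sample 0: neuron 0 fires 4 times
    [[0., 1., 0., 0.], [0., 0., 1., 0.]],    # sample 1: both neurons fire once
])
targets = jnp.array([[1., 0.], [0., 1.]])    # one-hot labels
losses = batched_loss(spikes, targets)       # one loss value per sample
```

The first sample, where the target neuron clearly fires most, gets a much smaller loss than the second, where both neurons fire equally often.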

Finally, we train the model by feeding our model the input spike trains and states. For this, we first have to initialize the states of the SNN using the init_states method of the Sequential class.

# ...
# Simple training loop
for in_spikes, tgt_class in tqdm(dataloader):
    # Initializing the membrane potentials of LIF neurons
    states = model.init_states(key)
    model, opt_state, loss = update(model, opt_state, states, in_spikes, tgt_class)
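
To show why the states have to be initialized before each forward pass, here is a minimal sketch of a leaky integrate-and-fire layer unrolled over time with jax.lax.scan. It is plain JAX and does not reflect how SNNAX implements its layers: `init_state`, `lif_step`, and `run` are hypothetical names, the reset-to-zero rule is an assumption, and time is the leading axis here because jax.lax.scan iterates over it.

```python
import jax
import jax.numpy as jnp


def init_state(num_neurons):
    """Hypothetical helper: start every membrane potential at zero."""
    return jnp.zeros(num_neurons)


def lif_step(v, x, decay=0.9, threshold=1.0):
    """One leaky integrate-and-fire step with reset to zero after a spike."""
    v = decay * v + x
    spike = (v >= threshold).astype(v.dtype)
    v = v * (1.0 - spike)            # reset the neurons that fired
    return v, spike


def run(inputs):
    """Scan lif_step over the leading (time) axis of `inputs`."""
    v0 = init_state(inputs.shape[1])
    v_final, spikes = jax.lax.scan(lif_step, v0, inputs)
    return v_final, spikes


inputs = jnp.full((5, 2), 0.6)       # 5 time steps, 2 neurons, constant drive
v_final, spikes = run(inputs)        # spikes has shape (5, 2)
```

The carried membrane potential is the state that must exist before the first time step, which is the role `init_states` plays in the training loop above.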

Fully worked-out examples can be found in the examples directory.

Citation

If you use SNNAX in your research, please cite the following paper:

@article{lohoff2024snnax,
    title={{SNNAX}: {S}piking {N}eural {N}etworks in {JAX}},
    author={Lohoff, Jamie and Finkbeiner, Jan and Neftci, Emre},
    journal={TBD},
    year={2024}
}

JAX Ecosystem

You can find JAX itself under https://github.com/google/jax.

equinox is available under https://github.com/patrick-kidger/equinox.

Other JAX libraries for SNN training: