Sparse neural networks in JAX, fully compatible with Equinox and Optax. To get started, see the example, or adapt an existing Equinox + Optax training loop as follows:
- Wrap your `optax` optimizer with `sparsenn.flatten`.
- Replace `eqx.filter_value_and_grad` with `sparsenn.filter_value_and_grad`.
- Replace `eqx.apply_updates` with `sparsenn.apply_updates`.
Use `sparsenn.vmap_chunked(f, in_axes=..., chunk_size=..., devices=...)` instead of `jax.vmap(f, in_axes=...)` to get a memory-limited (chunked with `scan`), multi-GPU `vmap`.
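The chunking idea can be illustrated with plain JAX. The sketch below is not `sparsenn.vmap_chunked`: `vmap_chunked_sketch` is a hypothetical helper that supports only `in_axes=0`, a single array argument and output, and a single device. It pads the batch to a multiple of `chunk_size` and vmaps `f` over one chunk per `jax.lax.scan` iteration, so peak memory for intermediates roughly scales with `chunk_size` rather than the full batch. It then trims the padding.

```python
import jax
import jax.numpy as jnp


def vmap_chunked_sketch(f, chunk_size):
    """Hypothetical helper: like jax.vmap(f) over axis 0, evaluated chunk_size rows at a time.

    Assumes f takes one array, returns one array, and is safe to call on zero-padded rows.
    """
    batched_f = jax.vmap(f)

    def run(x):
        n = x.shape[0]
        pad = (-n) % chunk_size  # rows needed to make n a multiple of chunk_size
        x_padded = jnp.pad(x, [(0, pad)] + [(0, 0)] * (x.ndim - 1))
        chunks = x_padded.reshape((-1, chunk_size) + x.shape[1:])

        def body(carry, chunk):
            # Each scan iteration vmaps f over a single chunk.
            return carry, batched_f(chunk)

        _, out = jax.lax.scan(body, None, chunks)  # out: (num_chunks, chunk_size, ...)
        out = out.reshape((-1,) + out.shape[2:])  # merge the chunk axes into one batch axis
        return out[:n]  # drop the padded rows

    return run


# Usage: matches plain jax.vmap on a batch whose size is not a multiple of chunk_size.
f = lambda v: jnp.sum(v ** 2)
x = jnp.arange(10.0).reshape(5, 2)
chunked = vmap_chunked_sketch(f, chunk_size=2)(x)
reference = jax.vmap(f)(x)
```

This covers only the memory-limiting half; the multi-GPU behaviour controlled by the `devices` argument is not reproduced here.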