# JAX Utils

Convenient functions for training neural networks with JAX and Flax.

Intended for personal use. Feel free to use it, but at your own risk!

## Installation

```bash
pip install --upgrade git+https://github.com/n2cholas/jax-utils.git
```
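This README does not list the package's own function names, so the sketch below does not use them and should not be read as this library's API. It is a minimal, hypothetical example in plain JAX (no Flax) of the hand-written training boilerplate that helpers like these typically wrap. A jitted step computes the loss and gradients with `jax.value_and_grad` and then applies an SGD update across the parameter pytree. The tiny MLP, the toy data, the names `init_params`, `loss_fn`, and `train_step`, and the learning rate are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative value, not taken from jax-utils


def init_params(key, in_dim, hidden_dim, out_dim):
    """Initialize a one-hidden-layer MLP (hypothetical helper)."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden_dim)) / jnp.sqrt(in_dim),
        "b1": jnp.zeros(hidden_dim),
        "w2": jax.random.normal(k2, (hidden_dim, out_dim)) / jnp.sqrt(hidden_dim),
        "b2": jnp.zeros(out_dim),
    }


def predict(params, x):
    # tanh hidden layer followed by a linear output layer.
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


def loss_fn(params, x, y):
    # Mean squared error between predictions and targets.
    return jnp.mean((predict(params, x) - y) ** 2)


@jax.jit
def train_step(params, x, y):
    # Loss value and gradients w.r.t. every leaf of the params dict.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain SGD: update each parameter leaf with its matching gradient.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# Toy regression problem: the target is the sum of the three input features.
data_key, init_key = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(data_key, (64, 3))
y = jnp.sum(x, axis=1, keepdims=True)
params = init_params(init_key, 3, 16, 1)

initial_loss = float(loss_fn(params, x, y))
for _ in range(200):
    params, loss = train_step(params, x, y)
final_loss = float(loss_fn(params, x, y))
print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
```

Writing even this small loop requires explicit parameter initialization, a jitted step, and a pytree-wide update. With Flax models and Optax optimizers, that kind of repetition is what a personal utility package tends to factor out.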