
Primary language: Python. License: MIT.

JAX Utils

Convenience functions for training neural networks with JAX and Flax. Intended for personal use; feel free to use them, but at your own risk!

Installation

pip install --upgrade git+https://github.com/n2cholas/jax-utils.git