Examples using the JAX + Flax framework.
conda create -n jax-m2 python=3.11
conda activate jax-m2
pip install -r requirements.txt
To run any of the examples listed below:
python {folder-name}/main.py
Folders:
basics/
: Assorted basic examples with JAX.
mnist/
: A minimal ConvNet for MNIST digit classification.
nf/
: A basic normalizing flow implementation.
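
For orientation, the change-of-variables idea behind a normalizing flow can be sketched in a few lines of JAX. This is an illustrative sketch only and is not taken from nf/main.py: the function names `forward` and `log_prob`, the parameter names `shift` and `log_scale`, and the choice of a single elementwise affine layer with a standard normal base density are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def forward(params, x):
    """Map data x to latent z with an elementwise affine transform.

    Hypothetical single-layer flow: z = (x - shift) * exp(-log_scale).
    Returns z and log|det(dz/dx)|, which is -sum(log_scale) here.
    """
    z = (x - params["shift"]) * jnp.exp(-params["log_scale"])
    log_det = -jnp.sum(params["log_scale"])
    return z, log_det


def log_prob(params, x):
    """Log-density of x via change of variables with a standard normal base."""
    z, log_det = forward(params, x)
    base_log_prob = -0.5 * jnp.sum(z**2 + jnp.log(2.0 * jnp.pi))
    return base_log_prob + log_det


# Identity initialisation: the flow starts out as a standard normal.
params = {"shift": jnp.zeros(2), "log_scale": jnp.zeros(2)}
x = jnp.array([0.5, -1.0])

# Gradient of the negative log-likelihood with respect to the parameters,
# which is what a training loop would pass to an optimizer.
nll_grad = jax.grad(lambda p: -log_prob(p, x))(params)
```

A real flow would typically stack several invertible layers and sum their log-determinants; the single affine layer above only shows how the density and its gradient are obtained.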