Code I am writing or modifying while learning the mighty JAX. See also learn-pytorch.
Clone the repo. In a terminal, run `cd learn-jax` to enter the repo, then run `pip install -r requirements.txt`. Finally, make a directory named `datasets` in the repo's root directory.
- Image classification, see (WIP)
- Image segmentation
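For orientation on these two topics, here is a minimal pure-JAX sketch: a single linear layer (softmax regression) trained with `jax.value_and_grad` and `jax.jit`, plus a per-pixel cross-entropy loss of the kind commonly used for semantic segmentation. It is not the code in this repo. The image size, class count, learning rate and all function names (`init_params`, `predict`, `cross_entropy`, `train_step`, `segmentation_loss`) are hypothetical, and the data is random noise used only to exercise the code.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes for illustration: 28x28 grayscale images, 10 classes.
IMG_H, IMG_W, NUM_CLASSES = 28, 28, 10
# Small learning rate so plain gradient descent stays stable on this toy problem.
LR = 0.005


def init_params(key):
    """Initialise one linear layer (softmax regression): flattened pixels -> class logits."""
    w = jax.random.normal(key, (IMG_H * IMG_W, NUM_CLASSES)) * 0.01
    b = jnp.zeros((NUM_CLASSES,))
    return {"w": w, "b": b}


def predict(params, images):
    """images: (batch, H, W) -> logits: (batch, NUM_CLASSES)."""
    x = images.reshape(images.shape[0], -1)  # flatten each image to a vector
    return x @ params["w"] + params["b"]


def cross_entropy(logits, labels):
    """Mean softmax cross-entropy over all leading axes; labels are integer class ids."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))


def classification_loss(params, images, labels):
    return cross_entropy(predict(params, images), labels)


@jax.jit
def train_step(params, images, labels):
    """One step of plain gradient descent; returns updated params and the pre-update loss."""
    loss, grads = jax.value_and_grad(classification_loss)(params, images, labels)
    new_params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return new_params, loss


def segmentation_loss(logits, masks):
    """Per-pixel cross-entropy: logits (batch, H, W, C), masks (batch, H, W) of class ids."""
    return cross_entropy(logits, masks)


# Random stand-in data, only to exercise the code (not a real dataset).
key = jax.random.PRNGKey(0)
k_params, k_img, k_lbl = jax.random.split(key, 3)
params = init_params(k_params)
images = jax.random.uniform(k_img, (32, IMG_H, IMG_W))
labels = jax.random.randint(k_lbl, (32,), 0, NUM_CLASSES)

losses = []
for _ in range(5):
    params, loss = train_step(params, images, labels)
    losses.append(float(loss))
print("loss per step:", losses)
```

The only difference between the two losses is which axes are averaged over: classification averages over the batch, segmentation over the batch and every pixel. A real model would usually replace the linear layer with a CNN (for example with Flax, as several of the resources below do) and plain gradient descent with an optimizer library such as Optax.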
Useful resources:

- https://github.com/NobuoTsukamoto/jax_examples
- https://github.com/google/flax/tree/main/examples (idea: contribute a segmentation example via a PR?)
- https://github.com/huggingface/transformers/tree/main/examples/flax
- https://github.com/n2cholas/awesome-jax
- https://github.com/craffel/jax-tutorial
- https://github.com/google/jax#quickstart-colab-in-the-cloud
- https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html
- https://github.com/gordicaleksa/get-started-with-JAX
- https://github.com/8bitmp3/JAX-Flax-Tutorial-Image-Classification-with-Linen
- https://wandb.ai/wandb_fc/tips/reports/How-To-Create-an-Image-Classification-Model-in-JAX-Flax--VmlldzoyMjA0Mjk1