
JAX - A curated list of resources https://github.com/google/jax

License: Creative Commons Zero v1.0 Universal (CC0-1.0)

Awesome JAX

JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators such as GPUs and TPUs.
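As a rough illustration of what "automatic differentiation plus XLA through a NumPy-like API" means in practice, the sketch below writes a toy squared-error loss with `jax.numpy`, differentiates it with `jax.grad`, compiles the gradient with `jax.jit`, and vectorizes the loss over a batch with `jax.vmap`. The `loss` function and the toy data are illustrative assumptions, not part of any library listed here.

```python
import jax
import jax.numpy as jnp

# Toy mean-squared-error loss written with the NumPy-like jax.numpy API.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates `loss` with respect to its first argument (w);
# jax.jit then compiles the resulting gradient function with XLA.
grad_loss = jax.jit(jax.grad(loss))

# jax.vmap maps `loss` over the leading axis of x and y (one row per example),
# while broadcasting the same weights w to every example (in_axes=None).
per_example_loss = jax.vmap(loss, in_axes=(None, 0, 0))

# Hypothetical toy data.
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = grad_loss(w, x, y)                     # gradient of the batch loss w.r.t. w
losses = per_example_loss(w, x, y)         # one loss value per example
```

With `w` at zero, the gradient works out to `[-7., -10.]` and the per-example losses to `[1., 4.]`; the same function runs unchanged on CPU, GPU, or TPU backends.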

This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!

Contents

  • Neural Network Libraries
    • Flax - Centered on flexibility and clarity.
    • Haiku - Focused on simplicity, created by the authors of Sonnet at DeepMind.
    • Objax - Has an object-oriented design similar to PyTorch.
    • Elegy - A framework-agnostic Trainer interface for the JAX ecosystem. Supports Flax, Haiku, and Optax.
    • RLax - Library for implementing reinforcement learning agents.
    • Trax - "Batteries included" deep learning library focused on providing solutions for common workloads.
    • Jraph - Lightweight graph neural network library.
    • Neural Tangents - High-level API for specifying neural networks of both finite and infinite width.
  • NumPyro - Probabilistic programming based on the Pyro library.
  • Chex - Utilities to write and test reliable JAX code.
  • Optax - Gradient processing and optimization library.
  • JAX, M.D. - Accelerated, differentiable molecular dynamics.
  • Coax - Turn RL papers into code, the easy way.
  • SymJAX - Symbolic CPU/GPU/TPU programming.
  • mcx - Express & compile probabilistic programs for performant inference.

New Libraries

This section contains libraries that are well-made and useful but have not necessarily been battle-tested by a large user base yet.

  • Neural Network Libraries
    • FedJAX - Federated learning in JAX, built on Optax and Haiku.
  • jax-unirep - Library implementing the UniRep model for protein machine learning applications.
  • jax-flows - Normalizing flows in JAX.
  • sklearn-jax-kernels - scikit-learn kernel matrices using JAX.
  • jax-cosmo - Differentiable cosmology library.
  • efax - Exponential Families in JAX.
  • mpi4jax - Combine MPI operations with your JAX code on CPUs and GPUs.
  • imax - Image augmentations and transformations.

Papers

This section contains papers focused on JAX (e.g., JAX-based library whitepapers and research on JAX). Papers implemented in JAX are listed in the Models/Projects section.

Contributing

Contributions welcome! Read the contribution guidelines first.