A curated list of awesome JAX libraries, projects, and other resources. Inspired by Awesome TensorFlow.
JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs. More info here.
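A minimal sketch of that core idea: a plain NumPy-style function is differentiated with `jax.grad`, compiled to XLA with `jax.jit`, and vectorized with `jax.vmap`. The loss function and data below are illustrative placeholders, not from any specific library above.

```python
import jax
import jax.numpy as jnp

# An ordinary NumPy-style function (hypothetical least-squares loss).
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

grad_loss = jax.grad(loss)          # automatic differentiation w.r.t. w
fast_grad = jax.jit(grad_loss)      # JIT-compiled via XLA
batched = jax.vmap(loss, in_axes=(None, 0, 0))  # vectorized over a batch axis

w = jnp.array([1.0, 2.0])
x = jnp.array([[0.5, -1.0], [1.5, 2.0]])
y = jnp.array([0.0, 1.0])
print(fast_grad(w, x, y))   # gradient of the loss w.r.t. w
print(batched(w, x, y))     # per-example losses
```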
- Neural Network Libraries
- Flax - a flexible library with the largest user base of all JAX NN libraries.
- Haiku - focused on simplicity, created by the authors of Sonnet at DeepMind.
- Objax - has an object-oriented design similar to PyTorch.
- Elegy - implements the Keras API with some improvements.
- RLax - library for implementing reinforcement learning agents.
- Trax - a "batteries included" deep learning library focused on providing solutions for common workloads.
- Jraph - a lightweight graph neural network library.
- NumPyro - probabilistic programming based on the Pyro library.
- Chex - utilities to write and test reliable JAX code.
- Optax - a gradient processing and optimization library (see the usage sketch after this list).
- JAX, M.D. - accelerated, differentiable molecular dynamics.
- Reformer - an implementation of the Reformer (efficient transformer) architecture.
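As a small example of how these pieces compose, here is a hedged sketch of a single Optax update step on hand-rolled parameters; the model, data, and learning rate are made up for illustration.

```python
import jax
import jax.numpy as jnp
import optax

# Hypothetical linear-model parameters and data.
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
x = jnp.ones((8, 3))
y = jnp.zeros((8,))

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

optimizer = optax.adam(learning_rate=1e-2)  # gradient transformation
opt_state = optimizer.init(params)

# One optimization step: differentiate, transform gradients, apply updates.
grads = jax.grad(loss_fn)(params, x, y)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
```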
- Introduction to JAX - a simple neural network from scratch in JAX.
- JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas - JAX’s core design, how it’s powering new research, and how you can start using it.
- Bayesian Programming with JAX + NumPyro — Andy Kitchen - introduction to Bayesian modelling using NumPyro.
- Compiling machine learning programs via high-level tracing. Roy Frostig, Matthew James Johnson, Chris Leary. MLSys 2018. - this white paper describes an early version of JAX, detailing how computation is traced and compiled.
- Reformer: The Efficient Transformer. Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. ICLR 2020. - introduces the Reformer architecture with O(n log n) self-attention via locality-sensitive hashing, providing significant gains in memory efficiency and speed on long sequences.
- JAX, M.D.: A Framework for Differentiable Physics. Samuel S. Schoenholz, Ekin D. Cubuk. NeurIPS 2020. - introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more.
- Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. arXiv 2020. - uses JAX's JIT and VMAP to achieve faster differentially private SGD than existing libraries.
- Using JAX to accelerate our research - describes the state of JAX and the JAX ecosystem at DeepMind.
- Getting started with JAX (MLPs, CNNs & RNNs) - neural network building blocks from scratch with the basic JAX operators.
- Plugging Into JAX - compares Flax, Haiku, and Objax on the Kaggle flower classification challenge.
Contributions welcome! Read the contribution guidelines first.