
JAX tutorial notebooks: basics, crashes & tips, usage of Optax/JaxOptim/Numpyro

Primary language: Jupyter Notebook. License: Educational Community License v2.0 (ECL-2.0).

JaxTutos

This repository provides notebooks to learn JAX (basics and advanced) and to use libraries such as JaxOptim/Numpyro/...

This work was partly performed using resources from GENCI–IDRIS (Grant 2024-[AD010413957R1]).

Questions:

  • Why JAX? You need automatic differentiation first, you want code that is ready to be accelerated on CPU/GPU/TPU devices, and you probably already know a bit of Python.
  • Will my code be scalable? Gemini (i.e. Google's counterpart to ChatGPT) is coded in JAX, so I guess you will find a good framework to get your use case working nicely.
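To make "auto-differentiation first, accelerated code ready for CPU/GPU/TPU" concrete, here is a minimal sketch (the function `f` is an arbitrary illustration, not taken from the notebooks): `jax.grad` builds the derivative of a Python function, and `jax.jit` compiles it with XLA for whatever backend is available.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Arbitrary scalar function for illustration: f(x) = sum(sin(x) * x**2)
    return jnp.sum(jnp.sin(x) * x**2)

# jax.grad returns a new function computing df/dx;
# jax.jit compiles it for the default device (CPU, GPU or TPU)
df = jax.jit(jax.grad(f))

x = jnp.array([0.0, 1.0, 2.0])
# Elementwise derivative, analytically cos(x) * x**2 + 2 * x * sin(x)
print(df(x))
```

The same code runs unchanged on CPU or GPU; only the installed `jaxlib` backend differs.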

Exchanges:

  • To discuss, use the Discussions menu.
  • To suggest new notebooks or code changes, and/or to report bugs, use Issues.

Here is the list of tutorials in this repo:

A tour of some basics

  • JAX_get_started.ipynb: get a flavour of JAX coding and an example of auto-diff
  • JAX_AutoDiff_UserCode.ipynb: more on the usage of auto-diff in a real use case, "integration methods"
  • JIT_fractals.ipynb: (GPU preferred) produce fractal images while discovering jax.lax control-flow functions and nested vmap
  • JAX_control_flow.ipynb: jax.lax control flow (fori_loop/scan/while_loop, cond). Some "crashes" are analysed.
  • JAX_exo_sum_image_patches.ipynb: exercise: sum patches of identical size from a 2D image. Experience the differences in compilation and execution times between methods on CPU and GPU (if possible).
  • JAX-MultiGPus.ipynb: (4 GPUs)* (e.g. on the Jean Zay JupyterHub platform) use the "data sharding module" to distribute arrays and perform parallelization (2D image production: a simple 2D function and a revisit of the Julia set from JIT_fractals.ipynb).
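As a rough idea of the jax.lax control flow plus nested vmap pattern used for fractals, here is a minimal sketch of a Julia-set escape-time image. It is not the notebook's code: the function names `julia_escape` and `julia_image`, the grid bounds and the constant `c` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def julia_escape(z, c, max_iter=50):
    """Count the iterations of z -> z**2 + c at which |z| is still <= 2."""
    def body(i, state):
        z, count = state
        z = z * z + c
        # Add 1 only if the orbit is still inside the disk of radius 2
        count = count + (jnp.abs(z) <= 2.0).astype(jnp.int32)
        return z, count
    # lax.fori_loop replaces a Python for-loop inside traced code
    _, count = jax.lax.fori_loop(0, max_iter, body, (z, jnp.int32(0)))
    return count

def julia_image(c, n=200, max_iter=50):
    xs = jnp.linspace(-1.5, 1.5, n)
    ys = jnp.linspace(-1.5, 1.5, n)
    # Inner vmap: one row (all x for a fixed y); outer vmap: all rows
    row = jax.vmap(lambda x, y: julia_escape(x + 1j * y, c, max_iter),
                   in_axes=(0, None))
    return jax.vmap(lambda y: row(xs, y))(ys)

# n fixes an array shape and max_iter a loop bound, so both are static for jit
julia_jit = jax.jit(julia_image, static_argnums=(1, 2))
img = julia_jit(-0.8 + 0.156j, 200, 50)  # (200, 200) int32 escape counts
```

Nesting two `vmap`s avoids writing explicit loops over pixels; XLA then compiles the whole grid computation at once.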

More advanced topics:

Designed for people with OO thinking (C++/Python) and/or with existing code in mind to transform into JAX. Based on real use cases I experienced. This is more advanced and technical, but with "crashes" analysed.
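A typical hurdle when porting OO code is passing an object (`self`) through `jax.jit` or `jax.grad`: a plain Python object is not a valid JAX type and raises a TypeError. One known pattern is to register the class as a pytree with `jax.tree_util.register_pytree_node_class`. The sketch below uses a hypothetical `Quadratic` class; it is an illustration of the pattern, not the approach necessarily taken in the notebooks.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Quadratic:
    """Hypothetical model f(x) = a * x**2 + b, usable under jit/grad."""
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __call__(self, x):
        return self.a * x**2 + self.b

    def tree_flatten(self):
        # Array attributes become pytree leaves (traced by JAX)
        children = (self.a, self.b)
        aux_data = None  # no static configuration in this example
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

model = Quadratic(jnp.array(1.0), jnp.array(3.0))

# The object can now be a jit argument...
evaluate = jax.jit(lambda m, x: m(x))
print(evaluate(model, 2.0))  # 1 * 2**2 + 3 = 7.0

# ...and grad returns a Quadratic holding d/da and d/db
grads = jax.grad(lambda m, x: m(x))(model, 2.0)
print(grads.a, grads.b)  # x**2 = 4.0 and 1.0
```

Anything that must stay a Python value (e.g. a configuration flag) would go into `aux_data` instead of the children.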

Using JAX & some third-party libraries for real work

  • JAX_jaxopt_optax.ipynb: some use of JaxOptim & Optax libraries
  • JAX_MC_Sampling.ipynb: a pedagogical notebook on Monte Carlo sampling using different techniques. Beyond the math, one experiences random number generation in JAX, which is a subject in itself. I implement a simple HMC MCMC both in Python and in JAX to see the difference.
  • Numpyro_MC_Sampling.ipynb: here we give some simple examples using the Numpyro Probabilistic Programming Language
  • JAX-GP-regression-piecewise.ipynb: (not ready for Colab) my implementation of a Gaussian Process library, to see the differences with Sklearn and GPy.
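Random number generation in JAX differs from NumPy: there is no global state, and every draw takes an explicit key that you split to obtain new independent streams. The minimal sketch below illustrates this with a toy Monte Carlo estimate of pi; the function `mc_pi` is a hypothetical example, not the HMC implementation of the notebook.

```python
import jax
import jax.numpy as jnp

def mc_pi(key, n):
    """Monte Carlo estimate of pi from n uniform points in the unit square."""
    u = jax.random.uniform(key, shape=(n, 2))
    inside = jnp.sum(jnp.sum(u**2, axis=1) <= 1.0)
    return 4.0 * inside / n

# n sets an array shape, so it must be static under jit
mc_pi_jit = jax.jit(mc_pi, static_argnums=1)

key = jax.random.PRNGKey(0)
# Split the key instead of reusing it: each subkey feeds its own stream
key, k1, k2 = jax.random.split(key, 3)
est1 = mc_pi_jit(k1, 100_000)
est2 = mc_pi_jit(k2, 100_000)

# Same key, same numbers: runs are fully reproducible
same = mc_pi_jit(k1, 100_000)
```

Reusing a key silently gives identical samples, which is a classic source of bugs when porting NumPy code.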

Other TUTOs (absolutely not exhaustive...)

Other JAX libraries:

  • Have a look at awesome-jax
  • More cosmology-centred:
    • JaxPM: JAX-powered Cosmological Particle-Mesh N-body Solver
    • S2FFT: JAX package for computing Fourier transforms on the sphere and rotation group
    • JAX-Cosmo: a differentiable cosmology library in JAX
    • JAX-GalSim: a JAX version (paper in draft) of the C++ GalSim code (GalSim is open-source software for simulating images of astronomical objects (stars, galaxies) in a variety of ways)
    • CosmoPower-JAX: an example of a differentiable emulator of cosmological power spectra
    • DISCO-DJ I: a differentiable Einstein-Boltzmann solver for cosmology; the repository is not yet public.
  • and many others, for instance for Simulation-Based Inference...

Installation (it depends on your local environment)

Most of the notebooks run on Colab (JAX 0.4.2x).

If you want a JaxTutos Conda environment (this is not guaranteed to work, because of the local and specific CUDA library to be used for the GPU-based notebooks):

conda create -n JaxTutos "python>=3.8"
conda activate JaxTutos
pip install --upgrade "jax[cuda]==<XYZ>" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install "numpyro>=0.10.1"
pip install "jaxopt>=0.6"
pip install "optax>=0.1.4"
pip install "corner>=2.2.1"
pip install "arviz>=0.11.4"
pip install matplotlib_inline
pip install "seaborn>=0.12.2"

Note that starting with JAX v0.4.30 the installation procedure changes: see the CHANGELOG.
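Whatever the installation route, a quick sanity check is to ask JAX which backend and devices it sees before opening the GPU-based notebooks. This is a minimal sketch using standard `jax` calls; the matrix size is arbitrary.

```python
import jax
import jax.numpy as jnp

print("JAX version:", jax.__version__)
print("Default backend:", jax.default_backend())  # typically 'cpu', 'gpu' or 'tpu'
print("Devices:", jax.devices())

# Smoke test: a jitted matrix product on the default device
x = jnp.ones((1000, 1000))
y = jax.jit(lambda a: a @ a)(x).block_until_ready()
print(y[0, 0])  # sum of 1000 products 1 * 1 -> 1000.0
```

If the backend reported is `cpu` while a GPU is present, the CUDA-enabled `jaxlib` is most likely not the one installed.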

Some Docs