`mpi4jax` enables zero-copy, multi-host communication of JAX arrays, even from jitted code and from GPU memory.
The JAX framework has great performance for scientific computing workloads, but its multi-host capabilities are still limited.
With `mpi4jax`, you can scale your JAX-based simulations to entire CPU and GPU clusters (without ever leaving `jax.jit`).
In the spirit of differentiable programming, `mpi4jax` also supports differentiating through some MPI operations.
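For instance, gradients can flow through a sum-reduction. The following is a minimal sketch (a hypothetical script, not from this README) of taking the gradient of a scalar loss through an `allreduce`:

```python
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD

def loss(x):
    # sum the local arrays across all ranks, then collapse to a scalar
    global_sum, _ = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm)
    return jnp.sum(global_sum ** 2)

# jax.grad differentiates through the allreduce like any other JAX primitive
grad_fn = jax.grad(loss)
print(grad_fn(jnp.ones(3)))
```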
`mpi4jax` is available through `pip` and `conda`:
```bash
$ pip install mpi4jax                  # Pip
$ conda install -c conda-forge mpi4jax # conda
```
Our documentation includes some more advanced installation examples.
```python
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

@jax.jit
def foo(arr):
    arr = arr + rank
    # sum the array across all ranks, from inside jitted code
    arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
    return arr_sum

a = jnp.zeros((3, 3))
result = foo(a)

if rank == 0:
    print(result)
```
Running this script on 4 processes gives:
```bash
$ mpirun -n 4 python example.py
[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]
```
Each rank adds its own rank to every element, so after the reduction each entry equals 0 + 1 + 2 + 3 = 6. `allreduce` is just one example of the MPI primitives you can use. See all supported operations here.
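Point-to-point communication follows the same pattern. Here is a sketch of what the `send_recv.py` script referenced in the debugging output below might look like (the exact keyword names are assumptions following mpi4py conventions); the token returned by each call is what `mpi4jax` uses to order MPI operations inside jitted code:

```python
from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

arr = jnp.ones(5) * rank

# run with: mpirun -n 2 python send_recv.py
if rank == 0:
    # rank 0 sends its array to rank 1; send returns a token
    token = mpi4jax.send(arr, dest=1, comm=comm)
elif rank == 1:
    # recv takes an array of the expected shape/dtype and returns the
    # received data along with a new token for chaining further calls
    arr, token = mpi4jax.recv(arr, source=0, comm=comm)
```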
We use pre-commit hooks to enforce a common code format. To install them, just run:
```bash
$ pip install pre-commit
$ pre-commit install
```
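The hooks then run automatically on every commit. To check the whole tree once by hand, you can use the standard `pre-commit` invocation (not specific to this project):

```bash
$ pre-commit run --all-files
```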
You can set the environment variable `MPI4JAX_DEBUG` to `1` to enable debug logging every time an MPI primitive is called from within a jitted function. You will then see messages like this:
```bash
$ MPI4JAX_DEBUG=1 mpirun -n 2 python send_recv.py
r0 | MPI_Send -> 1 with tag 0 and token 7fd7abc5f5c0
r1 | MPI_Recv <- 0 with tag -1 and token 7f9af7419ac0
```
- Filippo Vicentini @PhilipVinc
- Dion Häfner @dionhaefner