Overview | Example | Installation | Documentation
## Overview

jaxdf is a JAX-based package defining a coding framework for writing differentiable numerical simulators with arbitrary discretizations.

The intended use is to build numerical models of physical systems, such as wave propagation, or the numerical solution of partial differential equations, that are easy to customize to the user's research needs. Such models are pure functions that can be included in arbitrary differentiable programs written in JAX: for example, they can be used as layers of neural networks, or to build a physics loss function.
This branch is pinned to jax==0.2.29 and jaxlib==0.1.77, the latter compiled on Windows 11 with:

- Visual Studio 2019 (version 16.11.16)
- CUDA 11.3
- cuDNN 8.2.1
- Python 3.7

on a machine with an NVIDIA GTX 1650 graphics card.
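For reference, a minimal sketch of pinning these versions with pip is shown below. Official jaxlib wheels for these versions were published for Linux and macOS only; on Windows you will likely need a locally built jaxlib wheel, as described above.

```bash
# A sketch of pinning the versions this branch expects (assumes a
# platform for which prebuilt wheels exist; on Windows, install a
# locally compiled jaxlib wheel instead).
pip install jax==0.2.29 jaxlib==0.1.77
```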
## Example

The following script builds the non-linear operator (∇² + sin), using a Fourier spectral discretization on a square 2D domain, and uses it to define a loss function whose gradient is evaluated using JAX automatic differentiation.
```python
from jaxdf import operators as jops
from jaxdf import FourierSeries, operator
from jaxdf.geometry import Domain
from jax import numpy as jnp
from jax import jit, grad

# Defining the operator
@operator
def custom_op(u):
  grad_u = jops.gradient(u)
  diag_jacobian = jops.diag_jacobian(grad_u)
  laplacian = jops.sum_over_dims(diag_jacobian)
  sin_u = jops.compose(u)(jnp.sin)
  return laplacian + sin_u

# Defining the discretization
domain = Domain((128, 128), (1., 1.))
parameters = jnp.ones((128, 128, 1))
u = FourierSeries(parameters, domain)

# Defining a differentiable loss function
@jit
def loss(u):
  v = custom_op(u)
  return jnp.mean(jnp.abs(v.on_grid)**2)

gradient = grad(loss)(u)  # gradient is a FourierSeries
```
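As a quick sanity check (a minimal sketch, assuming the example above has run), the gradient is itself a `FourierSeries` on the same domain, and its raw array can be recovered with `.on_grid`:

```python
# Minimal sketch: inspect the result of the example above.
print(type(gradient))          # a FourierSeries field
print(gradient.on_grid.shape)  # (128, 128, 1), matching the parameters
```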
## Installation

Before installing jaxdf, make sure that you have installed JAX. Follow the instructions to install JAX with NVIDIA GPU support if you want to run jaxdf on a GPU.
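As a rough sketch only (the wheel tag and index URL below are assumptions based on the CUDA 11.x builds of that era, and varied between JAX releases; check the JAX installation guide for your CUDA setup):

```bash
# Assumed wheel tag and index URL for the CUDA 11.x builds of this era;
# verify both against the JAX installation guide before running.
pip install jax==0.2.29 jaxlib==0.1.77+cuda11.cudnn82 \
  -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```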
Install jaxdf by cloning the repository, or by downloading and extracting the compressed archive. Then navigate to the root folder in a terminal and run

```bash
pip install -r .requirements/requirements.txt
pip install .
```
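A quick way to verify the installation (a minimal sketch using only names from the example above):

```python
# Smoke test: import jaxdf and build a small domain.
from jaxdf.geometry import Domain

domain = Domain((4, 4), (1., 1.))
print(domain)
```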
## Citation

This package will be presented at the Differentiable Programming workshop at NeurIPS 2021.
```bibtex
@article{stanziola2021jaxdf,
  author={Stanziola, Antonio and Arridge, Simon and Cox, Ben T. and Treeby, Bradley E.},
  title={A research framework for writing differentiable PDE discretizations in JAX},
  year={2021},
  journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
}
```
## Acknowledgements

- Some of the packaging of this repository is done by editing this template from @rochacbruno
- The multiple-dispatch method employed is based on plum: https://github.com/wesselb/plum
## Related projects

- odl: the Operator Discretization Library (ODL) is a Python library for fast prototyping, focusing on (but not restricted to) inverse problems.
- deepXDE: a TensorFlow and PyTorch library for scientific machine learning.
- SciML: a NumFOCUS-sponsored open source software organization created to unify the packages for scientific machine learning.