Numerical quadrature with JAX

Primary language: Python. License: MIT.

quadax

quadax is a library for numerical quadrature and integration using JAX.

  • vmap-able, jit-able, differentiable.
  • Scalar- or vector-valued integrands.
  • Finite or infinite domains with discontinuities or singularities within the domain of integration.
  • Globally adaptive Gauss-Kronrod and Clenshaw-Curtis quadrature for smooth integrands (similar to scipy.integrate.quad).
  • Adaptive tanh-sinh quadrature for singular or near singular integrands.
  • Quadrature from sampled values using trapezoidal and Simpson's methods.
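
The tanh-sinh idea mentioned above can be illustrated with a minimal, non-adaptive pure-Python sketch. This is not quadax's implementation (quadax's version is adaptive and written in JAX), and the helper name tanh_sinh and its parameters h and t_max are hypothetical. The substitution x = (a + b)/2 + (b - a)/2 * tanh(pi/2 * sinh(t)) pushes endpoint singularities out to t = ±infinity, where the transformed integrand decays double-exponentially, so a plain trapezoidal sum in t converges quickly.

```python
import math


def tanh_sinh(f, a, b, h=1.0 / 16, t_max=4.0):
    """Non-adaptive tanh-sinh rule for the integral of f over [a, b].

    Illustrative sketch (hypothetical helper, not the quadax API): apply the
    substitution x = (a + b)/2 + (b - a)/2 * tanh(pi/2 * sinh(t)) and sum the
    trapezoidal rule in t over [-t_max, t_max] with step h.
    """
    n = int(round(t_max / h))
    total = 0.0
    for k in range(-n, n + 1):
        t = k * h
        u = 0.5 * math.pi * math.sinh(t)
        # x - a = (b - a) / (1 + exp(-2u)): written this way so nodes close to
        # a stay strictly above a instead of rounding onto the endpoint.
        # (Nodes near b may still round to b, so f must be finite there.)
        x = a + (b - a) / (1.0 + math.exp(-2.0 * u))
        # Jacobian dx/dt = (b - a)/2 * (pi/2) * cosh(t) / cosh(u)**2,
        # which decays double-exponentially as |t| grows.
        w = 0.25 * math.pi * (b - a) * math.cosh(t) / math.cosh(u) ** 2
        total += f(x) * w
    return h * total


# 1/sqrt(x) is singular at x = 0; the exact integral over [0, 1] is 2.
print(tanh_sinh(lambda x: 1.0 / math.sqrt(x), 0.0, 1.0))
```

Because the integrand is never evaluated exactly at x = 0, the integrable singularity causes no division by zero; an adaptive variant would typically halve h until successive sums agree to the requested tolerance.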

Coming soon:

  • Custom JVP/VJP rules (currently, automatic differentiation works by differentiating through the adaptive loop, which is not the most efficient approach).
  • N-D quadrature (cubature)
  • QMC methods
  • Integration with weight functions
  • Sparse grids (maybe; this needs some experimenting with data structures in JAX)

Installation

quadax is installable with pip:

pip install quadax

Usage

import jax.numpy as jnp
import numpy as np
from quadax import quadgk

fun = lambda t: t * jnp.log(1 + t)

# JAX uses 32-bit floats by default; tighter tolerances require enabling
# 64-bit mode, e.g. jax.config.update("jax_enable_x64", True)
epsabs = epsrel = 1e-5
a, b = 0, 1
y, info = quadgk(fun, [a, b], epsabs=epsabs, epsrel=epsrel)
assert info.err < max(epsabs, epsrel*abs(y))
np.testing.assert_allclose(y, 1/4, rtol=epsrel, atol=epsabs)
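
The README describes quadgk as globally adaptive Gauss-Kronrod quadrature similar to scipy.integrate.quad. To make the role of epsabs and epsrel concrete, here is a minimal pure-Python sketch of that general scheme; it is not quadax's code, it is not jit-able or differentiable, and the helpers gk15 and adaptive_gk (and the max_intervals parameter) are hypothetical. Each interval is integrated with the 15-point Kronrod rule, the difference from the embedded 7-point Gauss rule serves as an error estimate, and the interval with the largest estimated error is bisected until the total error satisfies max(epsabs, epsrel * |result|).

```python
import heapq
import math

# Gauss-Kronrod G7/K15 rule on [-1, 1] (the standard QUADPACK QK15 values).
# Only non-negative nodes are stored; XGK[7] = 0 is the centre node.
XGK = [0.991455371120812639206854697526329, 0.949107912342758524526189684047851,
       0.864864423359769072789712788640926, 0.741531185599394439863864773280788,
       0.586087235467691130294144845693013, 0.405845151377397166906606412076961,
       0.207784955007898467600689403773245, 0.0]
WGK = [0.022935322010529224963732008058970, 0.063092092629978553290700663189204,
       0.104790010322250183839876322541518, 0.140653259715525918745189590510238,
       0.169004726639267902826583426598550, 0.190350578064785409913256402421014,
       0.204432940075298892414161999234649, 0.209482141084727828012999174891714]
# 7-point Gauss weights for the nodes XGK[1], XGK[3], XGK[5] and XGK[7].
WG = [0.129484966168869693270611432679082, 0.279705391489276667901467771423780,
      0.381830050505118944950369775488975, 0.417959183673469387755102040816327]


def gk15(f, a, b):
    """Return (K15 estimate, |K15 - G7| error estimate) on [a, b]."""
    c, h = 0.5 * (a + b), 0.5 * (b - a)
    fc = f(c)
    kronrod, gauss = WGK[7] * fc, WG[3] * fc
    for j in range(7):
        fsum = f(c - h * XGK[j]) + f(c + h * XGK[j])
        kronrod += WGK[j] * fsum
        if j % 2 == 1:  # odd j are the nodes shared with the Gauss rule
            gauss += WG[j // 2] * fsum
    return h * kronrod, abs(h * (kronrod - gauss))


def adaptive_gk(f, a, b, epsabs=1.4e-8, epsrel=1.4e-8, max_intervals=50):
    """Globally adaptive quadrature: repeatedly bisect the worst interval."""
    val, err = gk15(f, a, b)
    heap = [(-err, a, b, val)]  # min-heap on -err pops the largest error
    while len(heap) < max_intervals:
        total = sum(item[3] for item in heap)
        total_err = sum(-item[0] for item in heap)
        if total_err <= max(epsabs, epsrel * abs(total)):
            break
        _, lo, hi, _ = heapq.heappop(heap)
        mid = 0.5 * (lo + hi)
        for left, right in ((lo, mid), (mid, hi)):
            v, e = gk15(f, left, right)
            heapq.heappush(heap, (-e, left, right, v))
    total = sum(item[3] for item in heap)
    total_err = sum(-item[0] for item in heap)
    return total, total_err


y, err = adaptive_gk(lambda t: t * math.log(1 + t), 0.0, 1.0, epsabs=1e-10, epsrel=1e-10)
print(y, err)
```

Using the raw |K15 - G7| difference as the error estimate is a simplification: QUADPACK (and hence scipy.integrate.quad) applies an empirical rescaling that usually gives a less pessimistic estimate for smooth integrands.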

For full details of the available options, see the API documentation.