GPJax aims to provide a low-level interface to Gaussian process (GP) models in JAX, structured to give researchers maximum flexibility in extending the code to suit their own needs. We define a GP prior in GPJax by specifying a mean and a kernel function, and multiply this by a likelihood function to construct the posterior. The aim is for the code to stay as close as possible to the mathematics we write on paper when working with GP models.
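In symbols, this construction is Bayes' rule: a GP prior with mean function m and kernel k, a likelihood for the observations (shown here for a Gaussian noise model with variance σ_n², as used in the quickstart below), and a posterior proportional to their product. The notation is ours, chosen for illustration.

```latex
f \sim \mathcal{GP}\big(m(\cdot),\, k(\cdot, \cdot)\big), \qquad
y_i \mid f \sim \mathcal{N}\big(f(x_i),\, \sigma_n^2\big), \qquad
p(f \mid \mathbf{y}) \propto p(\mathbf{y} \mid f)\, p(f).
```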
- Conjugate Inference
- Classification with MCMC
- Sparse Variational Inference
- BlackJax Integration
- TensorFlow Probability Integration
- Inference on Non-Euclidean Spaces
- Inference on Graphs
- Learning Gaussian Process Barycentres
- Deep Kernel Regression
This simple regression example illustrates how closely GPJax's API resembles the way we write the mathematics of Gaussian processes.
After importing the necessary dependencies, we'll simulate some data.
import gpjax as gpx
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.example_libraries import optimizers
from jax import jit
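key = jr.PRNGKey(123)
# Use independent keys for the inputs and the noise; reusing one key can give correlated samples.
x_key, noise_key = jr.split(key)
x = jr.uniform(key=x_key, minval=-3.0, maxval=3.0, shape=(50,)).sort().reshape(-1, 1)
y = jnp.sin(x) + jr.normal(noise_key, shape=x.shape) * 0.05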
training = gpx.Dataset(X=x, y=y)
The function of interest here is sinusoidal, but our observations of it have been perturbed by independent zero-mean Gaussian noise with standard deviation 0.05. We aim to use a Gaussian process to recover this latent function.
We begin by defining a zero-mean Gaussian process prior with a radial basis function kernel and assume the likelihood to be Gaussian.
prior = gpx.Prior(kernel=gpx.RBF())
likelihood = gpx.Gaussian(num_datapoints=x.shape[0])
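The RBF (squared exponential) kernel is k(x, x') = variance · exp(-‖x - x'‖² / (2 · lengthscale²)). As a sketch of what this kernel computes, the hypothetical helper `rbf_gram` below builds its Gram matrix with plain `jax.numpy`. It is not GPJax's implementation, the argument names `lengthscale` and `variance` are our own choice, and we make no assumption about the default values `gpx.RBF()` uses.

```python
import jax.numpy as jnp


def rbf_gram(x1, x2, lengthscale=1.0, variance=1.0):
    """Gram matrix of the RBF (squared exponential) kernel.

    x1 has shape (n, d) and x2 has shape (m, d); the result has shape (n, m)
    with entries variance * exp(-||x1_i - x2_j||^2 / (2 * lengthscale^2)).
    """
    # Pairwise squared Euclidean distances via broadcasting: (n, 1, d) - (1, m, d).
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


# Five equally spaced inputs: -3.0, -1.5, 0.0, 1.5, 3.0.
x = jnp.linspace(-3.0, 3.0, 5).reshape(-1, 1)
K = rbf_gram(x, x)
```

The matrix is symmetric with `variance` on its diagonal, and entries decay as inputs move further apart than about one lengthscale.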
The posterior is then constructed as the product of the prior and the likelihood.
posterior = prior * likelihood
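The `prior * likelihood` syntax relies on ordinary Python operator overloading. The toy classes below (`ToyPrior`, `ToyGaussianLikelihood` and `ToyPosterior`, all hypothetical and far simpler than anything in GPJax) sketch how a `__mul__` method can bundle a prior and a likelihood into a posterior object. This illustrates the pattern only; it is not GPJax's actual implementation.

```python
from dataclasses import dataclass


@dataclass
class ToyGaussianLikelihood:
    num_datapoints: int


@dataclass
class ToyPosterior:
    prior: "ToyPrior"
    likelihood: ToyGaussianLikelihood


@dataclass
class ToyPrior:
    kernel: str = "RBF"

    def __mul__(self, likelihood):
        # `prior * likelihood` dispatches here and bundles both into a posterior.
        return ToyPosterior(prior=self, likelihood=likelihood)


posterior = ToyPrior() * ToyGaussianLikelihood(num_datapoints=50)
```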
Equipped with the posterior, we proceed to train the model's hyperparameters through gradient-based optimisation of the marginal log-likelihood.
First, we define a set of initial parameter values through the initialise callable, and then map them to an unconstrained space so that the optimiser can update them freely.
params, _, constrainer, unconstrainer = gpx.initialise(posterior)
params = gpx.transform(params, unconstrainer)
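Why transform at all? Kernel lengthscales and variances, as well as the likelihood's noise variance, must be positive, whereas gradient-based optimisers move freely over the whole real line. A common choice of bijection between the two is the softplus function. The sketch below illustrates the idea with hypothetical helpers `to_constrained` and `to_unconstrained` and hypothetical parameter names; it is not necessarily the transformation GPJax applies.

```python
import jax.numpy as jnp


def to_constrained(u):
    # Softplus, log(1 + exp(u)): maps any real number to a strictly positive value.
    return jnp.logaddexp(u, 0.0)


def to_unconstrained(c):
    # Inverse softplus, log(exp(c) - 1), computed with expm1 for accuracy at small c.
    return jnp.log(jnp.expm1(c))


params = {
    "lengthscale": jnp.array(1.0),
    "variance": jnp.array(0.5),
    "obs_noise": jnp.array(0.1),
}

# Optimise in unconstrained space, then map back before using the values.
unconstrained = {name: to_unconstrained(value) for name, value in params.items()}
recovered = {name: to_constrained(value) for name, value in unconstrained.items()}
```

Any real value the optimiser produces maps back to a valid positive hyperparameter, so no step can leave the feasible region.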
Next, we define the negative marginal log-likelihood as our training objective and compile it with JAX's just-in-time (JIT) compiler to accelerate training. Notice that this is the first point at which data enters our model. This mirrors how Bayesian model building works in principle: we first define a prior model, then observe data and use it to form a posterior.
mll = jit(posterior.marginal_log_likelihood(training, constrainer, negative=True))
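For a zero-mean GP with Gaussian noise, the marginal log-likelihood has the closed form log p(y) = -½ yᵀS⁻¹y - ½ log|S| - (n/2) log 2π, where S = K + obs_noise · I and K is the kernel Gram matrix on the training inputs. The sketch below evaluates the negative of this quantity through a Cholesky factorisation, which is the standard numerically stable route. The function `negative_mll` and its dictionary keys are hypothetical; unlike the `mll` above, it takes already-constrained (positive) values and does not apply any constrainer.

```python
import jax.numpy as jnp
import jax.scipy.linalg as jsl


def rbf_gram(x1, x2, lengthscale, variance):
    # RBF Gram matrix: variance * exp(-||x1_i - x2_j||^2 / (2 * lengthscale^2)).
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def negative_mll(params, x, y):
    """Negative log marginal likelihood of a zero-mean GP with Gaussian noise.

    x has shape (n, d) and y has shape (n, 1); params holds positive values.
    """
    n = x.shape[0]
    S = rbf_gram(x, x, params["lengthscale"], params["variance"])
    S = S + params["obs_noise"] * jnp.eye(n)
    L = jnp.linalg.cholesky(S)                    # S = L L^T, L lower triangular
    alpha = jsl.cho_solve((L, True), y)           # alpha = S^{-1} y
    quad = 0.5 * jnp.sum(y * alpha)               # 0.5 * y^T S^{-1} y
    half_logdet = jnp.sum(jnp.log(jnp.diag(L)))   # 0.5 * log|S|
    return quad + half_logdet + 0.5 * n * jnp.log(2.0 * jnp.pi)
```

Because `negative_mll` is a pure function of a dictionary of parameters, `jax.grad(negative_mll)(params, x, y)` returns a dictionary of gradients with the same keys, which is the same pattern the training loop that follows relies on.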
Finally, we use the Adam optimiser from JAX's jax.example_libraries.optimizers module and run an optimisation loop for 100 iterations.
opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)
opt_state = opt_init(params)
def step(i, opt_state):
    params = get_params(opt_state)
    gradients = jax.grad(mll)(params)
    return opt_update(i, gradients, opt_state)

for i in range(100):
    opt_state = step(i, opt_state)
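To make the role of `opt_update` concrete, the sketch below writes out the standard Adam update by hand: exponential moving averages of the gradient and of the squared gradient, bias correction, and then a per-parameter scaled step, with the usual defaults β1 = 0.9, β2 = 0.999 and ε = 1e-8. It minimises a toy quadratic rather than the marginal log-likelihood, and the helpers `adam_init`, `adam_update` and `train_step` are hypothetical; `jax.example_libraries.optimizers.adam` may differ in implementation details.

```python
import jax
import jax.numpy as jnp


def adam_init(params):
    # First and second moment estimates start at zero, shaped like params.
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    return {"m": zeros, "v": zeros, "t": 0}


def adam_update(grads, state, params, step_size=0.01, b1=0.9, b2=0.999, eps=1e-8):
    t = state["t"] + 1
    m = jax.tree_util.tree_map(lambda m, g: b1 * m + (1 - b1) * g, state["m"], grads)
    v = jax.tree_util.tree_map(lambda v, g: b2 * v + (1 - b2) * g**2, state["v"], grads)
    # Bias-corrected moments, then a step scaled separately for each parameter.
    new_params = jax.tree_util.tree_map(
        lambda p, m, v: p - step_size * (m / (1 - b1**t)) / (jnp.sqrt(v / (1 - b2**t)) + eps),
        params, m, v,
    )
    return new_params, {"m": m, "v": v, "t": t}


def loss(p):
    # Toy objective with its minimum at a = 2.0, b = -1.0.
    return (p["a"] - 2.0) ** 2 + (p["b"] + 1.0) ** 2


@jax.jit
def train_step(params, state):
    grads = jax.grad(loss)(params)
    return adam_update(grads, state, params, step_size=0.05)


params = {"a": jnp.array(0.0), "b": jnp.array(0.0)}
state = adam_init(params)
for _ in range(500):
    params, state = train_step(params, state)
```

Here the whole update is JIT-compiled, whereas the quickstart compiles only the objective; either choice works, and compiling the full step simply removes more Python overhead per iteration.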
With the parameters optimised, we transform them back to their original constrained space. Using these learned values, we can obtain the posterior distribution of the latent function at new test points, and pass it through the likelihood to obtain the predictive distribution of the observations.
final_params = gpx.transform(get_params(opt_state), constrainer)
xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)
latent_distribution = posterior(training, final_params)(xtest)
predictive_distribution = likelihood(latent_distribution, final_params)
predictive_mean = predictive_distribution.mean()
predictive_stddev = predictive_distribution.stddev()
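The latent and predictive distributions differ only by the observation noise: the latent distribution describes f(x*), while the predictive distribution adds the Gaussian likelihood's noise variance to its variance. In this conjugate setting both are available in closed form. The sketch below computes them directly with JAX for a zero-mean GP with an RBF kernel; `gp_predict` is a hypothetical helper, and its hyperparameters are fixed illustrative values (the default noise variance 0.05² simply mirrors the simulated data) rather than learned ones.

```python
import jax.numpy as jnp
import jax.scipy.linalg as jsl


def rbf_gram(x1, x2, lengthscale, variance):
    # RBF Gram matrix: variance * exp(-||x1_i - x2_j||^2 / (2 * lengthscale^2)).
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def gp_predict(x, y, xtest, lengthscale=1.0, variance=1.0, obs_noise=0.05**2):
    """Mean, latent stddev and predictive stddev of a zero-mean GP at xtest."""
    n = x.shape[0]
    S = rbf_gram(x, x, lengthscale, variance) + obs_noise * jnp.eye(n)
    L = jnp.linalg.cholesky(S)                          # S = L L^T
    Kxt = rbf_gram(x, xtest, lengthscale, variance)     # shape (n, m)
    alpha = jsl.cho_solve((L, True), y)                 # S^{-1} y
    mean = (Kxt.T @ alpha).ravel()                      # shared by latent and predictive
    V = jsl.solve_triangular(L, Kxt, lower=True)        # L^{-1} K(x, x*)
    latent_var = variance - jnp.sum(V**2, axis=0)       # k(x*, x*) = variance for RBF
    latent_var = jnp.maximum(latent_var, 0.0)           # clip tiny negatives from round-off
    predictive_var = latent_var + obs_noise             # add the observation noise
    return mean, jnp.sqrt(latent_var), jnp.sqrt(predictive_var)
```

In the quickstart's terms, `predictive_mean` corresponds to the returned mean and `predictive_stddev` to the third return value.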
To install the latest stable version of GPJax, run:
pip install gpjax
To install the latest, possibly unstable, development version, follow the steps below. Although it is not compulsory, we advise doing all of this inside a virtual environment.
git clone https://github.com/thomaspinder/GPJax.git
cd GPJax
python setup.py develop
We then recommend you check your installation using the supplied unit tests.
python -m pytest tests/