JaxGaussianProcesses/GPJax

bug: AnalyticalGaussianIntegrator incorrect

meta-inf opened this issue · 2 comments

Bug Report

GPJax version: 0.7.2

Current behavior:

The following code compares the expected log-likelihood under a Gaussian likelihood, computed by the analytical integrator, against a Monte Carlo estimate. The two results differ significantly (-98 vs. -1e4).

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats
import gpjax as gpx

# Latent f ~ N(0, 1); a single observation y = 1 with noise standard deviation 0.01.
f_mean, f_var = jnp.zeros((1,)), jnp.ones((1,))
noise_stddev = jnp.array(0.01)
y = jnp.ones((1,))

# Analytical expected log-likelihood (input shape should be [N, D]).
likelihood = gpx.Gaussian(num_datapoints=1, obs_stddev=noise_stddev)
print(likelihood.expected_log_likelihood(y[None], f_mean[None], f_var[None]))

# Monte Carlo estimate of E_f[log N(y | f, noise_stddev^2)] from P samples of f.
P = 100000
f_samples = jax.random.normal(jax.random.PRNGKey(23), shape=(P, 1)) * f_var**0.5 + f_mean
log_y_given_f = jax.scipy.stats.norm.logpdf(
    jnp.tile(y[None], [P, 1]),
    f_samples,
    noise_stddev * jnp.ones([P, 1]))
# Sample mean and its standard error.
print(log_y_given_f.mean(0), log_y_given_f.std(0) / P**0.5)
```

Expected behavior:

The analytical value should agree with the Monte Carlo estimate up to its standard error.
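For reference, the expectation of a Gaussian log-density under a Gaussian posterior has a standard closed form. Here σ stands for obs_stddev, and μ and v for f_mean and f_var; the derivation uses only E[(y − f)²] = (y − μ)² + v:

$$
\begin{aligned}
\mathbb{E}_{f \sim \mathcal{N}(\mu, v)}\big[\log \mathcal{N}(y \mid f, \sigma^2)\big]
&= -\tfrac{1}{2}\log(2\pi\sigma^2) - \frac{\mathbb{E}\big[(y - f)^2\big]}{2\sigma^2} \\
&= -\tfrac{1}{2}\log(2\pi\sigma^2) - \frac{(y - \mu)^2 + v}{2\sigma^2}.
\end{aligned}
$$

With y = 1, μ = 0, v = 1 and σ = 0.01 this gives −½ log(2π · 10⁻⁴) − 2 / (2 · 10⁻⁴) ≈ 3.69 − 10000 ≈ −9996.3, which is consistent with the −1e4 Monte Carlo estimate.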

Steps to reproduce:

See above.

Related code:

See above.

Other information:

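I believe you forgot to square obs_stddev in the AnalyticalGaussianIntegrator code, so the noise standard deviation is used where the noise variance is needed.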
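A minimal sketch in plain JAX (not GPJax) of the closed-form expectation with and without squaring obs_stddev, checked against a Monte Carlo estimate like the one in the report. The helpers `closed_form_ell` and `closed_form_ell_unsquared` are hypothetical. The unsquared variant only illustrates the suspected bug and is not copied from GPJax's AnalyticalGaussianIntegrator.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def closed_form_ell(y, mean, variance, obs_stddev):
    """Hypothetical helper: E_{f ~ N(mean, variance)}[log N(y | f, obs_stddev**2)]."""
    obs_var = obs_stddev**2  # the noise *variance* is what enters the formula
    return -0.5 * jnp.log(2.0 * jnp.pi * obs_var) - 0.5 * ((y - mean) ** 2 + variance) / obs_var


def closed_form_ell_unsquared(y, mean, variance, obs_stddev):
    """Hypothetical helper: same formula, but obs_stddev is wrongly used as the variance."""
    return -0.5 * jnp.log(2.0 * jnp.pi * obs_stddev) - 0.5 * ((y - mean) ** 2 + variance) / obs_stddev


# Values from the bug report.
y, mean, variance, obs_stddev = 1.0, 0.0, 1.0, 0.01

correct = closed_form_ell(y, mean, variance, obs_stddev)
unsquared = closed_form_ell_unsquared(y, mean, variance, obs_stddev)

# Monte Carlo reference: average log N(y | f, obs_stddev^2) over samples f ~ N(mean, variance).
f = mean + jnp.sqrt(variance) * jax.random.normal(jax.random.PRNGKey(23), (100_000,))
mc = norm.logpdf(y, f, obs_stddev).mean()

print(correct, unsquared, mc)
```

With the report's values, the squared version gives about −9996.3 and lands close to the Monte Carlo estimate. The unsquared version gives about −98.6, which matches the −98 reported for the analytical integrator.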

Thanks @meta-inf. Quite right. If you fancy opening a PR - go for it! Or I'll open one in a bit. Your example would also make an excellent test! Cheers, Dan.

Fixed in #414 -- thanks @daniel-dodd !