bug: AnalyticalGaussianIntegrator incorrect
meta-inf opened this issue · 2 comments
meta-inf commented
Bug Report
GPJax version: 0.7.2
Current behavior:
The following code compares the expected Gaussian log-likelihood computed by the analytical integrator with a Monte Carlo estimate. The two results differ substantially (about -98 vs. about -1e4).
```python
import jax, jax.numpy as jnp, gpjax as gpx

# Latent f ~ N(0, 1) at a single input, observed y = 1 with small noise
f_mean, f_var = jnp.zeros((1,)), jnp.ones((1,))
noise_stddev = jnp.array(0.01)
y = jnp.ones((1,))

# Analytical expected log-likelihood from the Gaussian likelihood
likelihood = gpx.Gaussian(num_datapoints=1, obs_stddev=noise_stddev)
print(likelihood.expected_log_likelihood(y[None], f_mean[None], f_var[None]))  # input shape should be [N, D]

# Monte Carlo estimate: sample f, then average log N(y | f, noise_stddev**2)
P = 100000
f_samples = jax.random.normal(jax.random.PRNGKey(23), shape=(P, 1)) * f_var**0.5 + f_mean
log_y_given_f = jax.scipy.stats.norm.logpdf(
    jnp.tile(y[None], [P, 1]),
    f_samples,
    noise_stddev * jnp.ones([P, 1]))
# Monte Carlo mean and its standard error
print(log_y_given_f.mean(0), log_y_given_f.std(0) / P**0.5)
```
Expected behavior:
The two results should agree up to Monte Carlo error.
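For reference, the quantity being compared has a closed form. Writing $\sigma$ for `obs_stddev` and $m$, $v$ for `f_mean`, `f_var`, and using $\mathbb{E}[(y - f)^2] = (y - m)^2 + v$:

```latex
\mathbb{E}_{f \sim \mathcal{N}(m, v)}\left[\log \mathcal{N}(y \mid f, \sigma^2)\right]
  = -\tfrac{1}{2}\log\left(2\pi\sigma^2\right) - \frac{(y - m)^2 + v}{2\sigma^2}
```

With $y = 1$, $m = 0$, $v = 1$, $\sigma = 0.01$ this gives $3.686 - 10000 \approx -9996.3$, in line with the Monte Carlo estimate. Using $\sigma$ where $\sigma^2$ belongs instead gives $1.384 - 100 \approx -98.6$, which matches the value the analytical integrator printed.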
Steps to reproduce:
See above.
Related code:
See above.
Other information:
I believe you forgot to square `obs_stddev` in the code here.
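The effect of the missing square can be sketched in a few lines of pure JAX. Both functions below are hypothetical standalone helpers written for illustration, not GPJax functions. The "unsquared" variant is an assumption about what the buggy code computes, chosen because it reproduces the reported number.

```python
import jax.numpy as jnp


def expected_log_likelihood(y, mean, variance, obs_stddev):
    """Closed-form E_{f ~ N(mean, variance)}[log N(y | f, obs_stddev**2)].

    Hypothetical helper for illustration; not GPJax's API.
    """
    obs_var = obs_stddev**2  # the noise *variance* is the squared stddev
    return -0.5 * jnp.log(2.0 * jnp.pi * obs_var) - 0.5 * ((y - mean) ** 2 + variance) / obs_var


def expected_log_likelihood_unsquared(y, mean, variance, obs_stddev):
    # Hypothetical: same formula with the stddev used where the variance belongs,
    # i.e. the suspected bug reported in this issue.
    return -0.5 * jnp.log(2.0 * jnp.pi * obs_stddev) - 0.5 * ((y - mean) ** 2 + variance) / obs_stddev


y, m, v, s = jnp.array(1.0), jnp.array(0.0), jnp.array(1.0), jnp.array(0.01)
print(expected_log_likelihood(y, m, v, s))            # approx -9996.3
print(expected_log_likelihood_unsquared(y, m, v, s))  # approx -98.6
```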
daniel-dodd commented
Thanks @meta-inf. Quite right. If you fancy opening a PR, go for it! Or I'll open one in a bit. Your example would also make an excellent test! Cheers, Dan.
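One way the example could be turned into a regression test, sketched in pure JAX so it runs without GPJax. `closed_form_ell`, `monte_carlo_ell` and the test name are hypothetical. In a real GPJax test one would call `likelihood.expected_log_likelihood` as in the report instead of `closed_form_ell`. The tolerance of five Monte Carlo standard errors is an assumption chosen for this sketch.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def closed_form_ell(y, mean, variance, obs_stddev):
    # Hypothetical helper: analytical E_{f ~ N(mean, variance)}[log N(y | f, obs_stddev**2)]
    obs_var = obs_stddev**2
    return -0.5 * jnp.log(2.0 * jnp.pi * obs_var) - 0.5 * ((y - mean) ** 2 + variance) / obs_var


def monte_carlo_ell(key, y, mean, variance, obs_stddev, num_samples=100_000):
    # Hypothetical helper: draw f ~ N(mean, variance), average log N(y | f, obs_stddev**2)
    f = mean + jnp.sqrt(variance) * jax.random.normal(key, (num_samples,) + jnp.shape(mean))
    log_p = norm.logpdf(y, loc=f, scale=obs_stddev)
    # Return the sample mean and its standard error
    return log_p.mean(axis=0), log_p.std(axis=0) / jnp.sqrt(num_samples)


def test_expected_log_likelihood_matches_monte_carlo():
    y, mean, variance = jnp.ones((1,)), jnp.zeros((1,)), jnp.ones((1,))
    obs_stddev = jnp.array(0.01)
    mc_mean, mc_stderr = monte_carlo_ell(jax.random.PRNGKey(23), y, mean, variance, obs_stddev)
    analytic = closed_form_ell(y, mean, variance, obs_stddev)
    # The analytical value should lie within 5 Monte Carlo standard errors
    assert jnp.all(jnp.abs(analytic - mc_mean) < 5 * mc_stderr)
```

Under pytest the `test_` function is collected automatically. Swapping in a variance computed as `obs_stddev` rather than `obs_stddev**2` makes the assertion fail by a wide margin.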
meta-inf commented
Fixed in #414. Thanks, @daniel-dodd!