JaxGaussianProcesses/GPJax

bug: Objectives not being optimised in the right direction


Bug Report

GPJax version:

0.9.0

Current behavior:

The switch from the previous gpjax.objectives.ConjugateMLL class to the new gpjax.objectives.conjugate_mll function removed the negative parameter of the former superclass gpjax.objectives.AbstractObjective, but the new code was not changed to negate the MLL itself. Because gpx.fit minimises its objective, passing conjugate_mll directly minimises the marginal log-likelihood, so optimisation of model parameters does not work.
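A toy illustration of the sign problem, written without GPJax so it runs anywhere: a Gaussian log-likelihood in one parameter, and a hypothetical minimise helper standing in for the gradient-descent loop inside gpx.fit. It is a sketch under those assumptions, not GPJax code. Handing the raw log-likelihood to a minimiser pushes the parameter away from the maximum-likelihood estimate; handing it the negation converges to it.

```python
import math

def log_likelihood(mu, data, sigma=1.0):
    """Gaussian log-likelihood of `data` for mean `mu` and fixed std `sigma`."""
    n = len(data)
    sq = sum((x - mu) ** 2 for x in data)
    return -0.5 * n * math.log(2 * math.pi * sigma**2) - sq / (2 * sigma**2)

def grad_log_likelihood(mu, data, sigma=1.0):
    """Analytic derivative of log_likelihood with respect to `mu`."""
    return sum(x - mu for x in data) / sigma**2

def minimise(grad_fn, x0, lr=0.05, steps=200):
    """Hypothetical stand-in for the optimiser loop in gpx.fit:
    plain gradient descent, which always steps against the gradient."""
    x = x0
    for _ in range(steps):
        x -= lr * grad_fn(x)
    return x

data = [1.0, 2.0, 3.0]  # maximum-likelihood estimate of mu is the mean, 2.0

# Minimising the log-likelihood itself: mu runs away from 2.0.
fitted_wrong = minimise(lambda mu: grad_log_likelihood(mu, data), x0=0.0)

# Minimising its negation (what the workaround lambda does): mu converges to 2.0.
fitted_right = minimise(lambda mu: -grad_log_likelihood(mu, data), x0=0.0)
```

The same reasoning carries over to any objective that is meant to be maximised and is passed to a routine that minimises.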

Expected behavior:

Optimisation of model parameters should fit model to data.

Steps to reproduce:

Run the Simple Example in the README.

Related code:

To make optimisation work, we now have to do

opt_posterior, history = gpx.fit(
    model=posterior,
    # gpx.fit minimises its objective, so negate the marginal log-likelihood
    objective=lambda p, d: -1 * gpx.objectives.conjugate_mll(p, d),
    train_data=D,
    optim=optimiser,
    num_iters=500,
    safe=True,
    key=key,
)

instead of passing gpx.objectives.conjugate_mll directly as the objective argument.

Other information:

This also applies to log_posterior_density.
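Since log_posterior_density is affected too, one way to avoid repeating the lambda is a small wrapper. The helper name negated below is hypothetical and not part of GPJax; the sketch only assumes that objectives are called as objective(model, data), as in the workaround above.

```python
from typing import Callable, TypeVar

Model = TypeVar("Model")
Data = TypeVar("Data")

def negated(objective: Callable[[Model, Data], float]) -> Callable[[Model, Data], float]:
    """Hypothetical helper (not a GPJax API): turn an objective that should be
    maximised, such as a marginal log-likelihood, into a loss to minimise.
    Works for any return value supporting unary minus, e.g. JAX scalars."""
    def loss(model: Model, data: Data) -> float:
        return -objective(model, data)
    return loss

# Intended usage with GPJax (not executed here):
#   gpx.fit(model=posterior, objective=negated(gpx.objectives.conjugate_mll), ...)
#   gpx.fit(model=..., objective=negated(gpx.objectives.log_posterior_density), ...)
```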

Hi @huylenguyen - thanks for catching that the README needs updating. Would you be willing to open a PR to fix the README?

Wonderful! Please add me as a reviewer when the PR is ready.

Hi both,

Small question: has this been incorporated yet?

Hey @miguelgondu - no, I've not had the chance myself and have not heard anything further on the issue.