CosmoStat/jax-lensing

Implement simple VI

Closed this issue · 4 comments

EiffL commented

As a first goal, we can try to use a simple Gaussian prior on the convergence map and recover a posterior on the map by Variational Inference (VI).

For this, we will maximize the ELBO by stochastic variational inference (SVI). It's more or less the same technique as in a Variational Auto-Encoder.
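To make the idea concrete, here is a minimal NumPy sketch of a Monte Carlo ELBO estimate using the reparameterization trick, as in a VAE. It is only an illustration under stated assumptions: the names `gaussian_kl`, `elbo_estimate`, `forward`, `noise_sigma` and `prior_sigma` are hypothetical and not taken from the repository, the variational distribution is assumed to be a diagonal Gaussian, the prior is assumed to be a zero-mean isotropic Gaussian, and the likelihood normalization constant is dropped.

```python
import numpy as np


def gaussian_kl(mu, log_sigma, prior_sigma):
    """KL( N(mu, sigma^2) || N(0, prior_sigma^2) ), summed over pixels."""
    sigma_sq = np.exp(2.0 * log_sigma)
    return 0.5 * np.sum(
        (sigma_sq + mu**2) / prior_sigma**2
        - 1.0
        + 2.0 * np.log(prior_sigma)
        - 2.0 * log_sigma
    )


def elbo_estimate(mu, log_sigma, gamma_obs, forward,
                  noise_sigma=1.0, prior_sigma=1.0, n_samples=16, rng=None):
    """Monte Carlo ELBO estimate (up to an additive constant).

    `forward` is a hypothetical callable mapping a convergence map to the
    predicted shear; the noise is assumed Gaussian with std `noise_sigma`.
    """
    rng = np.random.default_rng() if rng is None else rng
    sigma = np.exp(log_sigma)
    log_lik = 0.0
    for _ in range(n_samples):
        eps = rng.standard_normal(mu.shape)
        kappa = mu + sigma * eps  # reparameterized sample from q
        residual = gamma_obs - forward(kappa)
        log_lik += -0.5 * np.sum(np.abs(residual) ** 2) / noise_sigma**2
    # Expected log-likelihood minus the analytic KL to the prior.
    return log_lik / n_samples - gaussian_kl(mu, log_sigma, prior_sigma)


# Toy usage with an identity forward operator (an assumption for illustration only).
truth = np.random.default_rng(0).standard_normal((8, 8))
log_sigma = np.full(truth.shape, -2.0)
elbo = elbo_estimate(truth, log_sigma, truth, lambda k: k,
                     noise_sigma=0.1, rng=np.random.default_rng(1))
```

In an actual SVI loop, `mu` and `log_sigma` would be updated by stochastic gradient ascent on this quantity; an autodiff framework would supply the gradients, which this NumPy sketch does not compute.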

Here is a nice article summarizing how to apply VI in practice. We don't need to adopt everything there, but the equations for the ELBO can be useful:
https://arxiv.org/abs/1603.00788

Also, the original VAE paper is not a bad reference for VI:
https://arxiv.org/abs/1312.6114

EiffL commented

@b-remy So step I is almost done if I understand correctly :-)
One thing to try before moving to the next steps would be to test the uncertainty estimation when we have missing data. For instance, imagine you have a hole in the middle of your survey: if everything works well, we should see the uncertainty go way up in that region.

To do this, your likelihood should be something like

|| M ( gamma - P k) ||_2

Where M is a mask equal to 0 or 1 depending on whether a pixel in the shear map is observed or not. So for instance, you could set M to 1 everywhere except for a small patch in the middle of the field.
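A minimal NumPy sketch of such a masked likelihood is given below. Several things here are assumptions rather than the repository's code: the function names (`ks_forward`, `masked_neg_log_likelihood`, `central_hole_mask`) are hypothetical, P is taken to be the flat-sky Kaiser-Squires operator on a periodic grid with one particular sign and axis convention, the shear is stored as the complex field gamma1 + i gamma2, and the norm is squared and scaled by a noise variance, which corresponds to Gaussian noise.

```python
import numpy as np


def ks_forward(kappa):
    """Flat-sky Kaiser-Squires forward operator: convergence -> complex shear.

    Returns gamma1 + 1j * gamma2 on a periodic grid. Sign and axis
    conventions are an assumption and may differ from the repository.
    """
    n1, n2 = kappa.shape
    k1 = np.fft.fftfreq(n1)[:, None]
    k2 = np.fft.fftfreq(n2)[None, :]
    k_sq = k1**2 + k2**2
    k_sq[0, 0] = 1.0  # avoid 0/0; the numerator is 0 there, so the kernel vanishes
    kernel = (k1**2 - k2**2 + 2j * k1 * k2) / k_sq
    return np.fft.ifft2(kernel * np.fft.fft2(kappa))


def masked_neg_log_likelihood(kappa, gamma_obs, mask, noise_sigma=1.0):
    """0.5 * || M (gamma - P kappa) ||^2 / noise_sigma^2, up to a constant."""
    residual = mask * (gamma_obs - ks_forward(kappa))
    return 0.5 * np.sum(np.abs(residual) ** 2) / noise_sigma**2


def central_hole_mask(shape, half_width):
    """Mask equal to 1 everywhere except a square patch in the middle."""
    mask = np.ones(shape)
    c1, c2 = shape[0] // 2, shape[1] // 2
    mask[c1 - half_width:c1 + half_width, c2 - half_width:c2 + half_width] = 0.0
    return mask


# Toy usage: noiseless shear from a random convergence map, with a central hole.
rng = np.random.default_rng(0)
kappa_true = rng.standard_normal((32, 32))
gamma_obs = ks_forward(kappa_true)
mask = central_hole_mask(kappa_true.shape, half_width=4)
nll = masked_neg_log_likelihood(kappa_true, gamma_obs, mask)
```

Because the residual is multiplied by the mask, shear values inside the hole do not contribute to the likelihood, so only the prior constrains the map there.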

EiffL commented

Perfect :-D ! Can you post here some of the plots you got? So that we can have a trace here as well of your progress? I'm thinking in particular about the plot showing the reconstruction mean and standard deviation.

I've looked at your code, and it looks pretty good :-) I think you are ready to open a Pull Request, right?

We can also talk about whether you want to already integrate some of your notebook code into a library. But this may be a tiny bit premature, maybe we should try to solve #2 first.

After implementing a point estimate of the convergence map here, I investigated the posterior on the map by VI, using a simple Gaussian prior and a Gaussian variational distribution here.

From an input shear map in which the masked pixels (missing data) are set to zero:
[image: shear]

the reconstructed convergence maps look like this:
[image: convergence]
We can observe peaks in the E-mode and no pattern in the B-mode, as expected. However, the reconstruction is not perfect; the Gaussian assumptions on the prior and the variational distribution may not provide enough flexibility.

Missing data are handled in this model by a higher uncertainty in the corresponding region:
[image: std]

A description of the method can be found in the notebook. It is essentially based on maximizing the ELBO function:
[image: elbo]
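For readers without access to the image, the standard textbook form of the ELBO for this kind of problem is written out below. This is not a transcription of the notebook's equation, which may use different notation: the symbols are assumptions, with $\gamma$ the observed shear, $\kappa$ the convergence map, $p(\kappa)$ the Gaussian prior and $q_\phi(\kappa)$ the Gaussian variational distribution with parameters $\phi$.

```latex
\mathrm{ELBO}(\phi)
  = \mathbb{E}_{q_\phi(\kappa)}\!\left[ \log p(\gamma \mid \kappa) \right]
  - \mathrm{KL}\!\left( q_\phi(\kappa) \,\|\, p(\kappa) \right)
  \;\le\; \log p(\gamma)
```

Maximizing this bound over $\phi$ trades off fitting the observed shear (first term) against staying close to the prior (second term).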

Indeed, it seems a bit premature to integrate the code into a library, so I will move on to #2.

EiffL commented

Very nice :-) thank you for that update!