Implementing spline interpolations
Closed this issue · 13 comments
JAX doesn't yet have all of the scipy interpolation tools, so we need our own.
Currently we only have a trivial linear interpolant here:
But it would be great to have higher-precision interpolation methods: they reduce the cost/size of the interpolation tables for a given accuracy, and table size is the main limiting factor in the growth or comoving distance calculations right now.
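To make the table-size argument concrete, here is a small sketch that measures how the maximum error of piecewise-linear interpolation shrinks as the table is refined. It assumes `jnp.interp` as a stand-in for the current linear interpolant and uses `sin` on [0, π] as an arbitrary test function; `max_linear_error` is a hypothetical helper. Halving the knot spacing only cuts the linear error by about 4×, because that error scales as h². For a smooth function with suitable end conditions, a cubic spline's error scales as h⁴, so it reaches the same accuracy with far fewer knots.

```python
import jax.numpy as jnp


def max_linear_error(n_knots):
    """Hypothetical helper: max error of linear interpolation of sin on [0, pi]."""
    # Tabulate the function on n_knots uniformly spaced points.
    x_knots = jnp.linspace(0.0, jnp.pi, n_knots)
    y_knots = jnp.sin(x_knots)
    # Compare the piecewise-linear interpolant to the truth on a much finer grid.
    t = jnp.linspace(0.0, jnp.pi, 10001)
    return jnp.max(jnp.abs(jnp.interp(t, x_knots, y_knots) - jnp.sin(t)))


coarse = max_linear_error(17)  # spacing h = pi / 16
fine = max_linear_error(33)    # spacing h = pi / 32
print(coarse, fine, coarse / fine)  # ratio ≈ 4, i.e. error ~ h**2
```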
We'd like a spline interpolation method similar to scipy's UnivariateSpline:
https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.interpolate.UnivariateSpline.html#scipy.interpolate.UnivariateSpline
@austinpeel said he would look into this :-)
I've written a class that attempts to mimic the behaviour of scipy's InterpolatedUnivariateSpline using JAX. So far it supports linear (k=1), quadratic (k=2), and cubic (k=3) interpolations (scipy's goes up to k=5). From testing, it seems to match very well for k=1 and k=3, but there are still some discrepancies with k=2. (But who uses quadratic interpolation anyway? ;))
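For anyone checking a JAX port against scipy, here is a minimal reference sketch that uses scipy's own `InterpolatedUnivariateSpline` for k = 1, 2, 3 on a Gaussian. The test function and knot grid are arbitrary choices. This class always interpolates, so it passes through every knot, and `.derivative()` gives a reference for both analytic derivatives and `jax.grad`.

```python
import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline

# Tabulated test function on uniformly spaced knots.
x = np.linspace(0.0, 3.0, 20)
y = np.exp(-x**2)
t = np.linspace(0.0, 3.0, 400)

splines = {}
max_errors = {}
for k in (1, 2, 3):
    # Interpolating spline of degree k (passes through all knots); scipy accepts 1 <= k <= 5.
    splines[k] = InterpolatedUnivariateSpline(x, y, k=k)
    max_errors[k] = np.max(np.abs(splines[k](t) - np.exp(-t**2)))

# Reference first derivative that a JAX port (analytic or via jax.grad) can be checked against.
dy_ref = splines[3].derivative()(t)
print(max_errors)
```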
The heavy lifting is done up front when the class is instantiated, meaning the spline coefficients are all solved for and stored straight away. You'll see there are some ops.index_update() calls that facilitate this calculation and probably need to be avoided eventually... But calls to evaluate the function and its derivatives (either analytically or with jax.grad) are jitted, working, and fast.
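`jax.ops.index_update(x, idx, y)` has since been removed from JAX. In current releases, the equivalent functional update is `x.at[idx].set(y)`. The sketch below builds the same tridiagonal matrix two ways: once with `.at[...].set(...)`, and once with no indexed writes at all via `jnp.diag`. A tridiagonal matrix is the kind of system a spline coefficient solve needs. Both function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def tridiag_with_at(lower, main, upper):
    """Hypothetical helper: build a tridiagonal matrix with functional .at[] updates."""
    n = main.shape[0]
    idx = jnp.arange(n)
    A = jnp.zeros((n, n), dtype=main.dtype)
    # Each .at[...].set(...) returns a new array; nothing is mutated in place.
    A = A.at[idx, idx].set(main)            # main diagonal
    A = A.at[idx[1:], idx[:-1]].set(lower)  # sub-diagonal
    A = A.at[idx[:-1], idx[1:]].set(upper)  # super-diagonal
    return A


def tridiag_without_updates(lower, main, upper):
    """Hypothetical helper: same matrix assembled from jnp.diag, no indexed writes."""
    return jnp.diag(main) + jnp.diag(lower, -1) + jnp.diag(upper, 1)


lower = jnp.array([1.0, 2.0, 3.0])
main = jnp.array([4.0, 5.0, 6.0, 7.0])
upper = jnp.array([8.0, 9.0, 10.0])

A_at = jax.jit(tridiag_with_at)(lower, main, upper)
A_diag = jax.jit(tridiag_without_updates)(lower, main, upper)
```

Both versions trace and jit fine. The `jnp.diag` form avoids scatter operations entirely, which is in the spirit of the index-update-free rewrite discussed later in this thread.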
I haven't looked into the pytrees stuff yet. I'm sure some restructuring will still be necessary to make the code fit smoothly into the larger jax_cosmo framework.
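One way to handle the pytree question is `jax.tree_util.register_pytree_node_class`. Arrays go in as leaves, so `jit`, `vmap` and `grad` can trace them, and plain Python settings go in the static auxiliary data. Here is a minimal sketch with a hypothetical `TabulatedInterpolant` class. It uses `jnp.interp` internally to keep the focus on the pytree mechanics rather than the spline maths.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class TabulatedInterpolant:
    """Hypothetical interpolant whose tabulated arrays are pytree leaves."""

    def __init__(self, x_knots, y_knots, fill_value=None):
        self.x_knots = x_knots        # leaf: traced by jit/vmap/grad
        self.y_knots = y_knots        # leaf: traced by jit/vmap/grad
        self.fill_value = fill_value  # static: stored in the treedef, not traced

    def __call__(self, t):
        if self.fill_value is None:
            # Clamp to the end values outside the table (jnp.interp default).
            return jnp.interp(t, self.x_knots, self.y_knots)
        return jnp.interp(
            t, self.x_knots, self.y_knots, left=self.fill_value, right=self.fill_value
        )

    def tree_flatten(self):
        children = (self.x_knots, self.y_knots)
        aux_data = self.fill_value
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, fill_value=aux_data)


@jax.jit
def evaluate(interp, t):
    # The interpolant is an ordinary jit argument; only fill_value is static.
    return interp(t)


x_knots = jnp.linspace(0.0, 3.0, 31)
y_knots = jnp.exp(-x_knots**2)
interp = TabulatedInterpolant(x_knots, y_knots)

t = jnp.linspace(0.0, 3.0, 7)
values = evaluate(interp, t)

# Gradient with respect to the whole object: returns a TabulatedInterpolant of gradients.
grads = jax.grad(lambda f: f(1.0))(interp)

padded = TabulatedInterpolant(x_knots, y_knots, fill_value=0.0)
outside = evaluate(padded, jnp.array([-1.0, 4.0]))
```

Because the object is a pytree, `jax.grad` with respect to it returns another `TabulatedInterpolant`. That object holds the gradients with respect to the tabulated values and the knot positions, which is what differentiating through a precomputed table needs.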
Lool, so a quick look at it turned into an evening quest ^^' But now I (almost) understand splines :-)
So, it's very weird. I first tried your code on an exponential function, just like in the scipy doc, and it worked perfectly. But then I tried it on some of the functions we want to interpolate in jax-cosmo, and essentially all hell broke loose ^^' I couldn't figure out what was wrong, but here is the kind of result I got:
The green line is the interpolation error using the scipy cubic spline, and the red line is the error using the jax-cosmo cubic spline... not great agreement ^^'
So then I tried to investigate but couldn't pin it down exactly, so I reimplemented it from scratch using some equations I found on the internet, and it almost looks like it's working :-) Also, this reimplementation doesn't use the update calls ;-)
Here is my experimentation notebook:
https://gist.github.com/EiffL/cdbab7d59a97854587e1a43dab97c220
I hope it's readable, but I didn't add many comments :-|
In any case, these splines are super exciting: the accuracy is orders of magnitude better than the stupid linear interpolation I had, so the code is gonna be so much faster and more accurate once they are in :-D
Yes indeed, good catch! The problem is due to using non-uniformly spaced knots. Here's an example of fitting a Gaussian again with a cubic spline using two different knot spacings.
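For reference, one standard textbook way to write the cubic-spline continuity conditions uses the second derivatives $M_i = S''(x_i)$ at the knots, with $y_i$ the tabulated values. This is not necessarily the exact form used in the code discussed here. Requiring the first derivative to be continuous at each interior knot gives the first line below. The second line is the symmetric form obtained when all spacings are equal. Using that second form on non-uniform knots silently drops the $h_i$ weights, which is consistent with the non-uniform-knot failure described in this thread.

```math
\begin{aligned}
h_{i-1} M_{i-1} + 2\,(h_{i-1} + h_i)\, M_i + h_i M_{i+1}
  &= 6\left(\frac{y_{i+1} - y_i}{h_i} - \frac{y_i - y_{i-1}}{h_{i-1}}\right),
  \qquad h_i = x_{i+1} - x_i, \\
M_{i-1} + 4 M_i + M_{i+1}
  &= \frac{6}{h^2}\left(y_{i+1} - 2 y_i + y_{i-1}\right)
  \qquad \text{(special case } h_{i-1} = h_i = h\text{)}.
\end{aligned}
```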
These are the errors relative to the true values for scipy and our JAX implementation. I reformulated the system of equations to explicitly account for the different spacings of neighbouring x intervals. (Previously I was using some simplified, but beautifully symmetric, equations that assumed equal spacing.) They are also now totally free of index updates, thanks to your concatenation trick ;)
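Here is a compact sketch of a spline solve that accounts for non-uniform spacing and avoids index updates. A few assumptions should be stated plainly:

- It uses *natural* end conditions ($M_0 = M_{n-1} = 0$). This is a swapped-in simplification, not what scipy's `InterpolatedUnivariateSpline` does near the ends, so the check is against `scipy.interpolate.CubicSpline(..., bc_type="natural")` instead.
- The helper names are hypothetical.
- The boundary zeros are attached with `jnp.concatenate`, which may or may not be the exact trick referred to above.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.interpolate import CubicSpline


def natural_second_derivatives(x, y):
    """Hypothetical helper: knot second derivatives M of a natural cubic spline."""
    h = jnp.diff(x)                # interval widths h_i = x_{i+1} - x_i (may differ)
    slopes = jnp.diff(y) / h       # secant slope on each interval
    rhs = 6.0 * jnp.diff(slopes)   # one equation per interior knot
    main = 2.0 * (h[:-1] + h[1:])  # 2 (h_{i-1} + h_i)
    off = h[1:-1]                  # h_i couples neighbouring interior knots
    A = jnp.diag(main) + jnp.diag(off, 1) + jnp.diag(off, -1)
    m_inner = jnp.linalg.solve(A, rhs)
    # Natural end conditions M_0 = M_{n-1} = 0, attached by concatenation
    # instead of writing into a preallocated array.
    zero = jnp.zeros((1,), dtype=m_inner.dtype)
    return jnp.concatenate([zero, m_inner, zero])


@jax.jit
def spline_eval(x, y, m, t):
    """Evaluate the cubic spline with knot values y and second derivatives m at t."""
    # Interval index i with x[i] <= t < x[i+1], clamped so the endpoints are valid.
    i = jnp.clip(jnp.searchsorted(x, t, side="right") - 1, 0, x.shape[0] - 2)
    h = x[i + 1] - x[i]
    a = x[i + 1] - t
    b = t - x[i]
    return (
        (m[i] * a**3 + m[i + 1] * b**3) / (6.0 * h)
        + (y[i] / h - m[i] * h / 6.0) * a
        + (y[i + 1] / h - m[i + 1] * h / 6.0) * b
    )


@jax.jit
def spline_deriv(x, y, m, t):
    """Analytic first derivative of the same piecewise cubic."""
    i = jnp.clip(jnp.searchsorted(x, t, side="right") - 1, 0, x.shape[0] - 2)
    h = x[i + 1] - x[i]
    a = x[i + 1] - t
    b = t - x[i]
    return (
        (m[i + 1] * b**2 - m[i] * a**2) / (2.0 * h)
        + (y[i + 1] - y[i]) / h
        - (m[i + 1] - m[i]) * h / 6.0
    )


# Knots that get sparser toward x = 3 (quadratic spacing), sampling a Gaussian.
x = 3.0 * jnp.linspace(0.0, 1.0, 15) ** 2
y = jnp.exp(-x**2)
m = natural_second_derivatives(x, y)

t = jnp.linspace(0.0, 3.0, 500)
values = spline_eval(x, y, m, t)
analytic = spline_deriv(x, y, m, t)
# jax.grad with respect to t (argnums=3), vectorised over the evaluation points.
grad_values = jax.vmap(jax.grad(spline_eval, argnums=3), in_axes=(None, None, None, 0))(x, y, m, t)

# Reference: scipy's natural cubic spline on the same knots, in float64.
reference = CubicSpline(
    np.asarray(x, dtype=np.float64), np.asarray(y, dtype=np.float64), bc_type="natural"
)
t_np = np.asarray(t, dtype=np.float64)
print(np.max(np.abs(np.asarray(values) - reference(t_np))))
```

Because `spline_eval` is plain JAX, `jax.grad` through it reproduces the analytic `spline_deriv`. This matches the earlier point in the thread that both derivative routes work.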
The error now matches scipy exactly, but both have trouble where the sampling gets sparser toward the right, near x=3. The function is also very close to zero there, so maybe that makes it worse? I'll check what happens now with the comoving distance calculation that we actually care about.
Fantastic! Thanks @austinpeel !
@all-contributors please add @austinpeel for code
I've put up a pull request to add @austinpeel! 🎉
Also you should open a PR on the google jax project for this :-)
So I have switched all interpolations to the new splines in branch u/EiffL/spline_integration,
and it works great! Only one unexpected side effect appeared: the splines increase the compilation time of the code ^^'
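To see whether the extra cost comes from compilation rather than execution, the two can be timed separately with JAX's ahead-of-time API, `jax.jit(f).lower(*args).compile()`. In this sketch, a hypothetical `interpolate_table` stands in for whatever jitted spline routine is being measured. Timings vary by machine, so no numbers are claimed here. One thing worth checking: if the coefficient setup is written as Python loops, tracing unrolls them, which is one common way compile time grows with table size.

```python
import time

import jax
import jax.numpy as jnp


def interpolate_table(x_knots, y_knots, t):
    # Hypothetical stand-in for any jitted interpolation routine (here: linear).
    return jnp.interp(t, x_knots, y_knots)


x_knots = jnp.linspace(0.0, 3.0, 64)
y_knots = jnp.exp(-x_knots**2)
t = jnp.linspace(0.0, 3.0, 1000)

jitted = jax.jit(interpolate_table)

start = time.perf_counter()
compiled = jitted.lower(x_knots, y_knots, t).compile()  # tracing + XLA compilation only
compile_seconds = time.perf_counter() - start

start = time.perf_counter()
result = compiled(x_knots, y_knots, t).block_until_ready()  # execution only
run_seconds = time.perf_counter() - start

print(f"compile: {compile_seconds:.4f} s, run: {run_seconds:.6f} s")
```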
Hi all, great work on these interpolation routines and thanks for sharing with the open-source community! Would you guys consider submitting a pull request to the main Jax repo with this?
These interpolation routines are very useful to scientists across many other fields (like me), so I'm sure your contribution would be appreciated! :)
Hi @peterdsharpe! Yes, that was our plan :-) but I think we got slightly sidetracked ^^'. Thanks for reaching out; this will give us the extra nudge we needed to actually do it. It's great to hear that this will be useful to more people!
And in the meantime, we got these splines working in a self-contained file over here: https://github.com/DifferentiableUniverseInitiative/jax_cosmo/blob/master/jax_cosmo/scipy/interpolate.py
Great to hear, @EiffL! Looking forward to the JAX PR, and thanks again! We're using your differentiable UnivariateSpline interpolator code to optimize magnet designs for fusion reactors, so suffice it to say the use cases are diverse haha! :)
Hi! I was wondering if any progress had been made in getting these higher order interpolators directly into google/jax? This would be extremely convenient for my work. Let me know if I can do anything to help with this.