DifferentiableUniverseInitiative/jax_cosmo

interp(x, xp, fp)


Hi,
`jax.numpy` has had an `interp` function since August 2020, with the same API as `numpy.interp`, so I think we could switch to `jnp.interp` instead of the custom implementation in https://github.com/DifferentiableUniverseInitiative/jax_cosmo/blob/master/jax_cosmo/scipy/interpolate.py
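A minimal sketch showing that `jnp.interp` takes the same `(x, xp, fp)` arguments as `numpy.interp` and returns matching values. It uses only the public `jax.numpy.interp` and `numpy.interp` functions; the lookup table is a made-up toy example, not data from jax_cosmo:

```python
import jax.numpy as jnp
import numpy as onp

# Toy lookup table (illustrative values only): f(x) = x**2 sampled on a grid.
xp = jnp.linspace(0.0, 1.0, 11)
fp = xp ** 2

# Query points between grid nodes.
x = jnp.array([0.05, 0.5, 0.95])

# Same call signature as numpy.interp: interp(x, xp, fp).
y_jax = jnp.interp(x, xp, fp)
y_np = onp.interp(onp.asarray(x), onp.asarray(xp), onp.asarray(fp))

print(y_jax, y_np)
```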

In `radial_comoving_distance` (`background.py`), line 240:

`a = np.atleast_1d(a)`

is not needed when using `jnp.interp`, and moreover it prevents taking a gradient with respect to `a`.

Maybe `np.asarray` would be enough.
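A minimal sketch of the gradient issue. It assumes a hypothetical toy lookup table (`a_table`, `chi_table`) and hypothetical functions `chi_atleast_1d` and `chi_asarray`; none of these are jax_cosmo's actual `radial_comoving_distance`. One plausible explanation for the failure is that `jax.grad` only accepts scalar-output functions, while `atleast_1d` turns a scalar `a` into an array of shape `(1,)`:

```python
import jax
import jax.numpy as jnp

# Hypothetical lookup table standing in for a precomputed distance table.
a_table = jnp.linspace(0.1, 1.0, 50)
chi_table = 1.0 / a_table - 1.0  # toy monotone "distance", illustration only

def chi_atleast_1d(a):
    # Mirrors the current pattern: a scalar input becomes shape (1,).
    a = jnp.atleast_1d(a)
    return jnp.interp(a, a_table, chi_table)

def chi_asarray(a):
    # Keeps the input's shape, so a scalar a gives a scalar output.
    a = jnp.asarray(a)
    return jnp.interp(a, a_table, chi_table)

# Works: scalar in, scalar out.
grad_ok = jax.grad(chi_asarray)(0.5)

# Fails: jax.grad rejects the shape-(1,) output with a TypeError.
try:
    jax.grad(chi_atleast_1d)(0.5)
    atleast_1d_failed = False
except TypeError:
    atleast_1d_failed = True

# For an array of scale factors, map the scalar gradient over the inputs.
grads = jax.vmap(jax.grad(chi_asarray))(jnp.array([0.3, 0.5, 0.8]))
print(grad_ok, atleast_1d_failed, grads)
```

With `asarray`, scalar and array inputs both keep their shape, and `jax.vmap` handles the batched case instead of forcing every input to be at least 1-D.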