DifferentiableUniverseInitiative/jax_cosmo

Functional or Object API for background quantities?

EiffL opened this issue · 2 comments

EiffL commented

Currently, functions like `H(cosmo, a)`, `radial_comoving_distance(cosmo, a)`, etc., which live in `background.py`, use a functional API. We could instead make these functions methods of a `Cosmology` object.

The question is: which interface is better?

```python
chi = bkgr.radial_comoving_distance(cosmo, a)
```

or

```python
chi = cosmo.radial_comoving_distance(a)
```

Probably the second one, but I'm asking in case anyone has thoughts on this before we switch.
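A minimal sketch of the two call styles, assuming a toy flat ΛCDM model. The `Cosmology` NamedTuple, the helper `E`, the constant `C_KM_S` and the trapezoidal integration are hypothetical illustrations, not jax_cosmo's actual implementation. Here the method simply delegates to the module-level function, so both spellings compute the same number.

```python
from typing import NamedTuple

import jax.numpy as jnp

C_KM_S = 299792.458  # speed of light in km/s


class Cosmology(NamedTuple):
    """Toy flat LambdaCDM parameters (hypothetical, for illustration only)."""
    Omega_m: float
    h: float

    def radial_comoving_distance(self, a):
        # Object API: a thin method that delegates to the functional version.
        return radial_comoving_distance(self, a)


def E(cosmo, a):
    # Dimensionless Hubble rate H(a)/H0 for a flat LambdaCDM universe.
    return jnp.sqrt(cosmo.Omega_m * a**-3 + (1.0 - cosmo.Omega_m))


def radial_comoving_distance(cosmo, a, n=512):
    # Functional API: chi(a) = c/H0 * int_a^1 da' / (a'^2 E(a')), in Mpc,
    # evaluated with a simple trapezoidal rule on n points.
    x = jnp.linspace(a, 1.0, n)
    y = 1.0 / (x**2 * E(cosmo, x))
    integral = jnp.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1]))
    return C_KM_S / (100.0 * cosmo.h) * integral


cosmo = Cosmology(Omega_m=0.3, h=0.7)
chi_functional = radial_comoving_distance(cosmo, 0.5)  # bkgr-style call
chi_method = cosmo.radial_comoving_distance(0.5)       # object-style call
```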

I think the second option is the best one. It's more consistent with what Python wrappers for Boltzmann codes do. Also, the cosmological parameters essentially define the background, so it makes more logical sense.

EiffL commented

I would agree, except that it prevents you from differentiating with respect to the cosmology without an auxiliary function.
With a functional API you can do:

```python
jax.grad(bkgr.radial_comoving_distance)(cosmo, a)
```

This returns the gradient with respect to the cosmology, because the cosmology is the function's first argument, which is what `jax.grad` differentiates by default.
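
To make the functional-API point concrete, here is a self-contained sketch with a hypothetical toy `Cosmology` (a `NamedTuple`, which JAX already treats as a pytree) and a toy flat ΛCDM distance; neither is jax_cosmo's real code. Because the cosmology is the first argument, `jax.grad` returns its gradient as another `Cosmology`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

C_KM_S = 299792.458  # speed of light in km/s


class Cosmology(NamedTuple):
    # NamedTuples are JAX pytrees, so jax.grad can return one as a gradient.
    Omega_m: float
    h: float


def radial_comoving_distance(cosmo, a, n=512):
    # Toy flat LambdaCDM: chi(a) = c/H0 * int_a^1 da' / (a'^2 E(a')), in Mpc.
    x = jnp.linspace(a, 1.0, n)
    E = jnp.sqrt(cosmo.Omega_m * x**-3 + (1.0 - cosmo.Omega_m))
    y = 1.0 / (x**2 * E)
    integral = jnp.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1]))
    return C_KM_S / (100.0 * cosmo.h) * integral


cosmo = Cosmology(Omega_m=0.3, h=0.7)
chi = radial_comoving_distance(cosmo, 0.5)
# argnums=0 by default, so the gradient has the same structure as `cosmo`:
# a Cosmology holding d(chi)/d(Omega_m) and d(chi)/d(h).
dchi = jax.grad(radial_comoving_distance)(cosmo, 0.5)
```

Since χ scales as 1/h in this model, the `h` component should come out as -χ/h, which is a quick sanity check on the gradient.
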
With an object API, on the other hand, you have to write:

```python
from jax import grad

def fn(cosmo):
    return cosmo.radial_comoving_distance(a)

grad(fn)(cosmo)
```

I guess you can write it in a shorter form, `grad(lambda c: c.radial_comoving_distance(a))(cosmo)`, but that's the main problem: the object API works against JAX.
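For comparison, a sketch of the object API under the same toy model (again hypothetical, not jax_cosmo's code). An ordinary class is opaque to JAX, so this version assumes the class is registered as a pytree with `jax.tree_util.register_pytree_node_class`; even then, the bound method has to be wrapped in a lambda so that the cosmology becomes an explicit argument of the differentiated function.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

C_KM_S = 299792.458  # speed of light in km/s


@register_pytree_node_class
class Cosmology:
    """Toy flat LambdaCDM cosmology with an object-style API (hypothetical)."""

    def __init__(self, Omega_m, h):
        self.Omega_m = Omega_m
        self.h = h

    def radial_comoving_distance(self, a, n=512):
        # chi(a) = c/H0 * int_a^1 da' / (a'^2 E(a')), in Mpc, trapezoidal rule.
        x = jnp.linspace(a, 1.0, n)
        E = jnp.sqrt(self.Omega_m * x**-3 + (1.0 - self.Omega_m))
        y = 1.0 / (x**2 * E)
        integral = jnp.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1]))
        return C_KM_S / (100.0 * self.h) * integral

    # Pytree registration: the parameters are the leaves jax.grad differentiates.
    def tree_flatten(self):
        return (self.Omega_m, self.h), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


cosmo = Cosmology(Omega_m=0.3, h=0.7)
a = 0.5
# The method is bound to `cosmo`, so wrap it to make the cosmology an argument.
dchi = jax.grad(lambda c: c.radial_comoving_distance(a))(cosmo)
```

The gradient comes back as a `Cosmology` instance either way; the difference is only whether the extra wrapper is needed at the call site.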