DifferentiableUniverseInitiative/jax_cosmo

Notes for improvements

EiffL opened this issue · 5 comments

EiffL commented

Discussing with @eelregit, here are a few ideas of things to improve:

  • Allow parameterisation in terms of As
  • Allow for flattening of the cosmology object
  • Switch to jax.numpy.interp !
  • Try to use jax.experimental.ode.odeint instead of jax_cosmo.ode
  • Configuration parameters stored in cosmo structure
  • include_logdet flag in gaussian_log_likelihood is reversed
  • Not sure if transverse_comoving_distance is actually jittable
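
For the interp and odeint items above, here is a minimal sketch of what the switch could look like. It uses a toy lookup table and a toy decay ODE rather than anything from jax_cosmo; the names `lookup`, `decay` and `solve` are made up for illustration. It only shows that `jax.numpy.interp` and `odeint` (which lives in `jax.experimental.ode`) can be jitted and differentiated.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

# Toy tabulated function (hypothetical data, not a jax_cosmo table).
xp = jnp.array([0.0, 1.0, 2.0])
fp = jnp.array([0.0, 10.0, 20.0])


@jax.jit
def lookup(x):
    # Linear interpolation of the table at x, compiled with jit.
    return jnp.interp(x, xp, fp)


def decay(y, t, k):
    # Right-hand side dy/dt = -k * y; odeint expects the signature func(y, t, *args).
    return -k * y


@jax.jit
def solve(k):
    # Integrate from t=0 to t=1 and return the solution at every time in t.
    t = jnp.linspace(0.0, 1.0, 11)
    return odeint(decay, jnp.array(1.0), t, k)


y_interp = lookup(0.5)
y_final = solve(1.0)[-1]
# Gradient of the final value with respect to the rate parameter k.
dy_dk = jax.grad(lambda k: solve(k)[-1])(1.0)
```

Whether this is a drop-in replacement for jax_cosmo.ode depends on the solver options the existing code relies on, so treat it as a starting point only.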

I was testing using custom pytree aux_data to store config parameters but saw some very strange behavior:

from jax import jit
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class C:
    def __init__(self, p, config={'a': 0}):
        self.p = p
        self._config = config

    def tree_flatten(self):
        children = (self.p,)
        aux_data = self._config
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        c = cls(*children)
        c._config.update(aux_data)
        return c
    
def f(c):
    c.p = 1. + c.p
    return c

g = jit(f)

(
    vars(C(0.)),
    vars(f(C(0.))),
    vars(f(C(0., {'a': 1}))),
    vars(g(C(0., {'b': 2}))),
)

returning

({'p': 0.0, '_config': {'a': 0, 'b': 2}},
 {'p': 1.0, '_config': {'a': 0, 'b': 2}},
 {'p': 1.0, '_config': {'a': 1}},
 {'p': DeviceArray(1., dtype=float32, weak_type=True), '_config': {'a': 0, 'b': 2}})

Since the pytree docs are pretty incomplete, I am still unsure whether storing configuration parameters as aux_data is a use intended by the JAX devs.
I should probably open an issue there.

EiffL commented

hummmmmm not sure I see the problem? What result were you expecting for the jitted function?

E.g., where are the 'b': 2 entries (except the last one) coming from?

I am still wondering: are there some examples where optimizing sigma_8 instead of A_s is better?

I feel the sigma_8 coordinate system would make an optimizer focus too much on the 8 Mpc/h scale, so that all parameters get adjusted to prioritize agreement there. So it doesn't feel natural and may not be a good default in general?

Okay, that was just me not knowing one should not use mutable default arguments in python.... 😅
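
For reference, a sketch of the same class without the shared mutable default. This is one possible fix, not necessarily what jax_cosmo ends up doing. Two choices here are assumptions on my part: user-supplied config is merged on top of the defaults, and aux_data is converted to a sorted tuple of items because JAX expects aux_data to be hashable and comparable. `f` also returns a new object instead of mutating its argument.

```python
import jax
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class C:
    def __init__(self, p, config=None):
        self.p = p
        # Build a fresh dict per instance so no state is shared between objects.
        self._config = {'a': 0}
        if config is not None:
            self._config.update(config)

    def tree_flatten(self):
        children = (self.p,)
        # Hashable, order-independent representation of the config.
        aux_data = tuple(sorted(self._config.items()))
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, config=dict(aux_data))


def f(c):
    # Return a new instance rather than modifying the input in place.
    return C(1. + c.p, c._config)


g = jax.jit(f)

out = g(C(0., {'b': 2}))
fresh = C(0.)
```

With this version, `fresh._config` no longer picks up the `'b': 2` entry from the earlier call, since each instance owns its own dict.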