Notes for improvements
EiffL opened this issue · 5 comments
After discussing with @eelregit, here are a few ideas for things to improve:
- Allow parameterisation in terms of A_s
- Allow for flattening of the cosmology object
- Switch to jax.numpy.interp!
- Try to use jax.experimental.ode.odeint instead of jax_cosmo.ode
- Store configuration parameters in the cosmo structure
- include_logdet flag in gaussian_log_likelihood is reversed
- Not sure if transverse_comoving_distance is actually jittable
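For the jax.numpy.interp item, a minimal sketch of what the switch buys: jnp.interp does piecewise-linear interpolation on a precomputed table and works under both jit and grad. The table (a^2 on a grid of scale factors) and the names a_grid, y_grid and y_of_a are hypothetical stand-ins, not jax_cosmo code.

```python
import jax
import jax.numpy as jnp

# Hypothetical lookup table: y = a**2 sampled on a grid of scale factors.
a_grid = jnp.linspace(0.01, 1.0, 256)
y_grid = a_grid ** 2


@jax.jit
def y_of_a(a):
    # Piecewise-linear interpolation on the precomputed table; accepts a
    # scalar or an array `a` and is traceable by jit.
    return jnp.interp(a, a_grid, y_grid)


# Differentiate the interpolant with respect to the query point.
dy_da = jax.grad(y_of_a)
```

As with NumPy, queries outside [a_grid[0], a_grid[-1]] return the end values by default (the left/right arguments override this).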
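For the odeint item, a toy example of jax.experimental.ode.odeint, which takes func(y, t, *args), an initial state y0 and an array of output times t, and supports reverse-mode gradients with respect to the extra arguments. The exponential-decay ODE and the loosened float32 tolerances are illustrative assumptions; an actual replacement for jax_cosmo.ode would plug in the relevant cosmological ODE instead.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def rhs(y, t, k):
    """Toy ODE dy/dt = -k * y; odeint calls func(y, t, *args)."""
    return -k * y


# Times at which the solution is requested (the first entry is the initial time).
t_eval = jnp.linspace(0.0, 2.0, 5)


def solve(k):
    y0 = jnp.array(1.0)
    # Extra positional arguments after `t` are forwarded to `rhs`;
    # tolerances are loosened to suit float32.
    return odeint(rhs, y0, t_eval, k, rtol=1e-6, atol=1e-9)


ys = jax.jit(solve)(0.5)
# Reverse-mode derivative of the final value with respect to k.
dyT_dk = jax.grad(lambda k: solve(k)[-1])(0.5)
```

The analytic solution is y(t) = exp(-k t), so dyT_dk should be close to -2 exp(-1).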
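To pin down what include_logdet is expected to mean, here is a hypothetical stand-alone Gaussian log-likelihood (not jax_cosmo's gaussian_log_likelihood) in which include_logdet=True adds the -1/2 log|C| term and False drops it; the constant -n/2 log(2π) is omitted. The log-determinant only matters when the covariance depends on the parameters being inferred.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="include_logdet")
def gaussian_loglike_sketch(data, mu, cov, include_logdet=True):
    """Hypothetical Gaussian log-likelihood, up to the constant -n/2 log(2 pi)."""
    r = data - mu
    # Solve C x = r rather than forming C^{-1} explicitly.
    chi2 = r @ jnp.linalg.solve(cov, r)
    loglike = -0.5 * chi2
    if include_logdet:  # static Python bool, resolved at trace time
        _, logdet = jnp.linalg.slogdet(cov)
        loglike = loglike - 0.5 * logdet
    return loglike
```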
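One common reason a function like transverse_comoving_distance fails under jit is a Python if on a traced value such as Omega_k. We have not checked how jax_cosmo implements it; the following is a hypothetical branch-free version of the standard relation D_M = d_h / sqrt(omega_k) * sinh(sqrt(omega_k) * chi / d_h) for omega_k > 0 (sin with |omega_k| for omega_k < 0, and D_M = chi when flat). It takes the line-of-sight comoving distance chi and the Hubble distance d_h = c / H0 as inputs (an assumption for illustration). A short Taylor series around u = omega_k * (chi / d_h)^2 = 0 keeps both the value and the gradient with respect to omega_k finite in the flat case.

```python
import jax
import jax.numpy as jnp


@jax.jit
def transverse_comoving_distance_sketch(chi, omega_k, d_h):
    """Hypothetical branch-free D_M(chi) valid for any sign of omega_k.

    chi: line-of-sight comoving distance; d_h: Hubble distance c / H0 in the
    same length units as chi; omega_k: curvature density parameter.
    """
    u = omega_k * (chi / d_h) ** 2
    small = jnp.abs(u) < 1e-3
    # Substitute a harmless value when |u| is small so the unused branch
    # never evaluates 0/0 (NaNs would otherwise leak into gradients).
    u_safe = jnp.where(small, 1.0, u)
    s = jnp.sqrt(jnp.abs(u_safe))
    exact = jnp.where(u_safe > 0, jnp.sinh(s) / s, jnp.sin(s) / s)
    # Series shared by the open and closed cases: 1 + u/6 + u^2/120.
    series = 1.0 + u / 6.0 + u ** 2 / 120.0
    return chi * jnp.where(small, series, exact)
```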
I was testing the use of custom pytree aux_data to store config parameters, but saw some very strange behavior:
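from jax import jit
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class C:
    # NOTE: this default dict is created once, when `def` runs, and is
    # shared by every instance constructed without an explicit `config`.
    def __init__(self, p, config={'a': 0}):
        self.p = p
        self._config = config

    def tree_flatten(self):
        children = (self.p,)     # dynamic leaves, traced by jit
        aux_data = self._config  # static metadata stored in the treedef
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        c = cls(*children)
        # Updates whatever dict c._config points to (here, the shared default).
        c._config.update(aux_data)
        return c


def f(c):
    c.p = 1. + c.p
    return c


g = jit(f)

# Displayed as the cell output in a notebook
(
    vars(C(0.)),
    vars(f(C(0.))),
    vars(f(C(0., {'a': 1}))),
    vars(g(C(0., {'b': 2}))),
)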
This returns:
({'p': 0.0, '_config': {'a': 0, 'b': 2}},
{'p': 1.0, '_config': {'a': 0, 'b': 2}},
{'p': 1.0, '_config': {'a': 1}},
{'p': DeviceArray(1., dtype=float32, weak_type=True), '_config': {'a': 0, 'b': 2}})
Since the pytree docs are pretty incomplete, I am still not sure whether storing configuration parameters as aux_data is what the JAX devs intended. I should probably open an issue there.
Hmm, not sure I see the problem? What result were you expecting for the jitted function?
E.g., where are the 'b': 2 entries (except the last one) coming from?
I am still wondering: are there examples where optimizing sigma_8 instead of A_s is better?
I feel the sigma_8 coordinate system would make an optimizer focus too much on the 8 Mpc/h scale, adjusting all parameters to prioritize agreement at that scale. So it doesn't feel natural, and may not be a good default in general?
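For reference, the standard linear-theory definitions connecting the two parameterisations: sigma_8 is the RMS of the linear density field smoothed with a spherical top-hat of radius R = 8 Mpc/h, so the window W weights P_lin(k) mostly around k ~ 1/R, and P_lin scales linearly with A_s when all other parameters are held fixed.

```latex
\sigma_8^2 = \frac{1}{2\pi^2} \int_0^\infty k^2 \, P_{\rm lin}(k, z=0) \, W^2(kR) \, dk,
\qquad R = 8\,h^{-1}\,{\rm Mpc},
\qquad W(x) = \frac{3\,(\sin x - x\cos x)}{x^3}

P_{\rm lin} \propto A_s
\;\Longrightarrow\;
\sigma_8 \propto \sqrt{A_s}
\;\Longrightarrow\;
A_s = A_s^{\rm ref} \left( \frac{\sigma_8}{\sigma_8^{\rm ref}} \right)^2
```

Here sigma_8^ref is the value computed from a reference A_s^ref with all other cosmological parameters unchanged.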
Okay, that was just me not knowing that one should not use mutable default arguments in Python... 😅
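With the mutable default identified as the culprit, here is a minimal sketch of the same toy class that avoids it: the constructor builds a fresh config for each instance and stores it as a sorted tuple of items, and tree_unflatten passes aux_data to the constructor instead of mutating a shared dict. The config property and the tuple encoding are illustrative choices, not an established JAX or jax_cosmo convention.

```python
from jax import jit
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class C:
    """Toy container: `p` is a traced leaf, the config is static metadata."""

    def __init__(self, p, config=None):
        self.p = p
        # Build a fresh config per instance instead of sharing one mutable
        # default dict; store it as an immutable, hashable tuple of items.
        merged = {'a': 0}
        if config is not None:
            merged.update(dict(config))
        self._config = tuple(sorted(merged.items()))

    @property
    def config(self):
        return dict(self._config)

    def tree_flatten(self):
        children = (self.p,)     # dynamic leaves, traced by jit
        aux_data = self._config  # static, hashable metadata
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Hand aux_data to the constructor rather than mutating a shared dict.
        return cls(*children, config=aux_data)


def f(c):
    # Return a new instance instead of mutating the input in place.
    return C(1. + c.p, config=c._config)


g = jit(f)
```

JAX treats aux_data as part of the static tree structure, so an immutable, hashable value such as a tuple seems a safer choice than a dict; a side effect is that changing the config changes the structure, which likely triggers a retrace under jit.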