google-deepmind/distrax

vmap produces wrong results silently

JTaets opened this issue · 2 comments

JTaets commented
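import jax
import jax.numpy as jnp
import distrax

rng = jax.random.PRNGKey(0)
loc = jax.random.uniform(rng, (4, 1))
logscale = jnp.array([1.0])

# Batched distribution built directly by broadcasting: 4 independent 1-D Gaussians.
pi1 = distrax.MultivariateNormalDiag(loc, jnp.exp(logscale))

def apply_dist(x, logscale):
    return distrax.MultivariateNormalDiag(x, jnp.exp(logscale))

# Same distributions built with vmap (apply_dist already exponentiates logscale).
pi2 = jax.vmap(apply_dist, (0, None))(loc, logscale)

print(pi1.sample(seed=rng).shape)  # (4, 1)
print(pi2.sample(seed=rng).shape)  # (4, 4)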
JTaets commented

I suppose that calling sample inside the vmapped function fixes this.

JTaets commented

But when computing the entropy of the full distribution, the results are wrong.