vmap produces wrong results silently
JTaets opened this issue · 2 comments
JTaets commented
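```python
import jax
import jax.numpy as jnp
import distrax

rng = jax.random.PRNGKey(0)
loc = jax.random.uniform(rng, (4, 1))
logscale = jnp.array([1.])
# Batched distribution built directly from the (4, 1) loc.
pi1 = distrax.MultivariateNormalDiag(loc, jnp.exp(logscale))
def apply_dist(x, logscale):
    # Builds the distribution for a single batch element.
    return distrax.MultivariateNormalDiag(x, jnp.exp(logscale))
# apply_dist already exponentiates, so pass logscale itself.
pi2 = jax.vmap(apply_dist, (0, None))(loc, logscale)
print(pi1.sample(seed=rng).shape)  # (4, 1)
print(pi2.sample(seed=rng).shape)  # (4, 4), expected (4, 1)
```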
JTaets commented
I suppose that sampling inside the vmapped function works around this.
JTaets commented
But when computing the entropy of the full vmapped distribution, the results are wrong.