Sampling from arrays
ColCarroll opened this issue · 0 comments
It looks as though flowMC fails to sample when the initial point is either 1-dimensional or has size 1.
Consider sampling from a multivariate normal:
import jax
import jax.numpy as jnp
from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
from flowMC.sampler.MALA import MALA
from flowMC.sampler.HMC import HMC
from flowMC.sampler.Sampler import Sampler
from flowMC.utils.PRNG_keys import initialize_rng_keys
n_chains=10
def log_density(x, data):
    return jnp.sum(-x**2)
n_dim = 2
rng_key_set = initialize_rng_keys(n_chains, seed=42)
model = MaskedCouplingRQSpline(
    n_layers=3, hidden_size=[64, 64], num_bins=8, n_features=n_dim,
    key=jax.random.PRNGKey(21))
local_sampler = MALA(log_density, True, params={"step_size": 0.1})
nf_sampler = Sampler(
    # added the n_loop_training and n_loop_production
    n_dim=n_dim,
    rng_key_set=rng_key_set,
    data={},
    local_sampler=local_sampler,
    nf_model=model,
    n_chains=n_chains)
- If we use
nf_sampler.sample(1., {})
the following is thrown:
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
- If we use:
nf_sampler.sample(jnp.ones(n_dim), {}) # where n_dim > 1
the following is thrown:
ValueError: Incompatible shapes for broadcasting: shapes=[(10, 50, 2), (2, 1)]
- If we use
nf_sampler.sample(jnp.atleast_2d(jnp.ones(n_dim)), {}) # where n_dim > 1
the code will run.
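For anyone hitting the same errors, here is a minimal workaround sketch. It assumes the sampler wants an initial position of shape `(n_chains, n_dim)`, which the shapes in the broadcasting error suggest but which is not confirmed here. `as_initial_position` is a hypothetical helper, not part of flowMC.

```python
import jax.numpy as jnp


def as_initial_position(x, n_chains, n_dim):
    """Hypothetical helper: coerce x to shape (n_chains, n_dim).

    Assumption: the sampler expects one row per chain and one column
    per dimension. This is inferred from the errors above, not from
    the flowMC documentation.
    """
    x = jnp.asarray(x)
    if x.ndim == 0:
        # Scalar, e.g. 1.0: treat as a single value shared by all entries.
        x = x.reshape(1, 1)
    elif x.ndim == 1:
        # One point of shape (n_dim,): add a leading chain axis.
        x = x[None, :]
    elif x.ndim != 2:
        raise ValueError(f"expected at most 2 dimensions, got shape {x.shape}")
    # Copy the single point across chains; raises if the shapes are incompatible.
    return jnp.broadcast_to(x, (n_chains, n_dim))


initial_position = as_initial_position(jnp.ones(2), n_chains=10, n_dim=2)
```

The result could then be passed as `nf_sampler.sample(initial_position, {})`. Starting every chain from the same point may not be desirable in practice, so this is only meant to illustrate the shape handling.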
In case n_dim = 1, the last two cases also fail with
ValueError: diag input must be 1d or 2d
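Both the `vmap` and `diag` errors can be triggered in plain JAX by passing a rank-0 array, as the sketch below shows. This only illustrates the class of failure. Where exactly flowMC ends up with a rank-0 array is an assumption on my part and not something I have traced. `raises_value_error` is a helper made up for this example.

```python
import jax
import jax.numpy as jnp


def raises_value_error(fn, *args):
    """Hypothetical helper: report whether fn(*args) raises ValueError."""
    try:
        fn(*args)
    except ValueError:
        return True
    return False


square = jax.vmap(lambda x: x**2)

# A rank-0 input has no axis 0 to map over, so vmap rejects it.
scalar_fails = raises_value_error(square, jnp.asarray(1.0))

# A (n_chains, n_dim) input has a leading axis and maps fine.
batched_ok = not raises_value_error(square, jnp.ones((10, 2)))

# jnp.diag accepts only 1-D or 2-D input, so a rank-0 array is rejected.
diag_fails = raises_value_error(jnp.diag, jnp.asarray(1.0))
```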
I'm not sure if this is intended behavior, but figured I'd flag these cases!