pyro-ppl/numpyro

Error in model using CircularReparam when trying to use Predictive

Closed this issue · 1 comments

Hello,

I'm trying to use the Predictive class with a model whose response variable is circular. I get a deep NotImplementedError when I try to use Predictive with the (recommended) circular reparameterization, but not when I don't use this.

Is this error related to not being able to use reparameterization on the observed variable (likelihood)? Or something specific to circular reparameterization? Any help much appreciated!

Minimal example:

import numpyro
import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.handlers import reparam
from numpyro.infer.reparam import CircularReparam

x = jnp.linspace(-jnp.pi, jnp.pi)
y = x + random.normal(random.key(234), shape=x.shape)  # acknowledged that this will be out-of-bounds

def model(x, y=None):
    b = numpyro.sample("b", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    kappa = 1 / sigma**2
    numpyro.sample(
        "obs",
        dist.VonMises(loc=b * x, concentration=kappa),
        obs=y,
    )


reparam_model = reparam(model, config={"obs": CircularReparam()})

# Start from this source of randomness. We will split keys for subsequent operations.
rng_key = random.key(4159)
rng_key, rng_key_ = random.split(rng_key)

# Run NUTS.
kernel = NUTS(reparam_model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
mcmc.run(rng_key_, x=x, y=y)
mcmc.print_summary()
samples_1 = mcmc.get_samples()

# generate posterior predictions on fitted values
rng_key, rng_key_ = random.split(rng_key)
predictive = Predictive(reparam_model, samples_1)
predictions = predictive(rng_key_, x=x)["obs"]

Error trace:

{
	"name": "NotImplementedError",
	"message": "",
	"stack": "---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[78], line 4
      2 rng_key, rng_key_ = random.split(rng_key)
      3 predictive = Predictive(reparam_model, samples_1)
----> 4 predictions = predictive(rng_key_, x=x)[\"obs\"]
      5 df[\"Mean Predictions\"] = jnp.mean(predictions, axis=0)
      6 df.head()

File .../python3.12/site-packages/numpyro/infer/util.py:1037, in Predictive.__call__(self, rng_key, *args, **kwargs)
   1027 \"\"\"
   1028 Returns dict of samples from the predictive distribution. By default, only sample sites not
   1029 contained in `posterior_samples` are returned. This can be modified by changing the
   (...)
   1034 :param kwargs: model kwargs.
   1035 \"\"\"
   1036 if self.batch_ndims == 0 or self.params == {} or self.guide is None:
-> 1037     return self._call_with_params(rng_key, self.params, args, kwargs)
   1038 elif self.batch_ndims == 1:  # batch over parameters
   1039     batch_size = jnp.shape(jax.tree.flatten(self.params)[0][0])[0]

File .../python3.12/site-packages/numpyro/infer/util.py:1013, in Predictive._call_with_params(self, rng_key, params, args, kwargs)
   1001     posterior_samples = _predictive(
   1002         guide_rng_key,
   1003         guide,
   (...)
   1010         exclude_deterministic=self.exclude_deterministic,
   1011     )
   1012 model = substitute(self.model, self.params)
-> 1013 return _predictive(
   1014     rng_key,
   1015     model,
   1016     posterior_samples,
   1017     self._batch_shape,
   1018     return_sites=self.return_sites,
   1019     infer_discrete=self.infer_discrete,
   1020     parallel=self.parallel,
   1021     model_args=args,
   1022     model_kwargs=kwargs,
   1023     exclude_deterministic=self.exclude_deterministic,
   1024 )

File .../python3.12/site-packages/numpyro/infer/util.py:846, in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, exclude_deterministic, model_args, model_kwargs)
    844 rng_key = rng_key.reshape(batch_shape + key_shape)
    845 chunk_size = num_samples if parallel else 1
--> 846 return soft_vmap(
    847     single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size
    848 )

File .../python3.12/site-packages/numpyro/util.py:452, in soft_vmap(fn, xs, batch_ndims, chunk_size)
    446     xs = jax.tree.map(
    447         lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]),
    448         xs,
    449     )
    450     fn = vmap(fn)
--> 452 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
    453 map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
    454 ys = jax.tree.map(
    455     lambda y: jnp.reshape(
    456         y, (int(np.prod(jnp.shape(y)[:map_ndims])),) + jnp.shape(y)[map_ndims:]
    457     )[:batch_size],
    458     ys,
    459 )

    [... skipping hidden 12 frame]

File .../python3.12/site-packages/numpyro/infer/util.py:819, in _predictive.<locals>.single_prediction(val)
    810         return (
    811             samples.get(msg[\"name\"]) if msg[\"type\"] != \"deterministic\" else None
    812         )
    814     substituted_model = (
    815         substitute(masked_model, substitute_fn=_samples_wo_deterministic)
    816         if exclude_deterministic
    817         else substitute(masked_model, samples)
    818     )
--> 819     model_trace = trace(seed(substituted_model, rng_key)).get_trace(
    820         *model_args, **model_kwargs
    821     )
    822     pred_samples = {name: site[\"value\"] for name, site in model_trace.items()}
    824 if return_sites is not None:

File .../python3.12/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
    163 def get_trace(self, *args, **kwargs):
    164     \"\"\"
    165     Run the wrapped callable and return the recorded trace.
    166 
   (...)
    169     :return: `OrderedDict` containing the execution trace.
    170     \"\"\"
--> 171     self(*args, **kwargs)
    172     return self.trace

File .../python3.12/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File .../python3.12/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

    [... skipping similar frames: Messenger.__call__ at line 105 (3 times)]

File .../python3.12/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

Cell In[72], line 5, in model(x, y)
      3 sigma = numpyro.sample(\"sigma\", dist.Exponential(1))
      4 kappa = 1 / sigma**2
----> 5 numpyro.sample(
      6     \"obs\",
      7     dist.VonMises(loc=b * x, concentration=kappa),
      8     obs=y,
      9 )

File .../python3.12/site-packages/numpyro/primitives.py:222, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
    207 initial_msg = {
    208     \"type\": \"sample\",
    209     \"name\": name,
   (...)
    218     \"infer\": {} if infer is None else infer,
    219 }
    221 # ...and use apply_stack to send it to the Messengers
--> 222 msg = apply_stack(initial_msg)
    223 return msg[\"value\"]

File .../python3.12/site-packages/numpyro/primitives.py:47, in apply_stack(msg)
     45 pointer = 0
     46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47     handler.process_message(msg)
     48     # When a Messenger sets the \"stop\" field of a message,
     49     # it prevents any Messengers above it on the stack from being applied.
     50     if msg.get(\"stop\"):

File .../python3.12/site-packages/numpyro/handlers.py:583, in reparam.process_message(self, msg)
    580 if reparam is None:
    581     return
--> 583 new_fn, value = reparam(msg[\"name\"], msg[\"fn\"], msg[\"value\"])
    585 if value is not None:
    586     if new_fn is None:

File .../python3.12/site-packages/numpyro/infer/reparam.py:344, in CircularReparam.__call__(self, name, fn, obs)
    342 # Draw parameter-free noise.
    343 new_fn = dist.ImproperUniform(constraints.real, fn.batch_shape, fn.event_shape)
--> 344 value = numpyro.sample(
    345     f\"{name}_unwrapped\",
    346     new_fn,
    347     obs=obs,
    348 )
    350 # Differentiably transform.
    351 value = jnp.remainder(value + math.pi, 2 * math.pi) - math.pi

File .../python3.12/site-packages/numpyro/primitives.py:222, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
    207 initial_msg = {
    208     \"type\": \"sample\",
    209     \"name\": name,
   (...)
    218     \"infer\": {} if infer is None else infer,
    219 }
    221 # ...and use apply_stack to send it to the Messengers
--> 222 msg = apply_stack(initial_msg)
    223 return msg[\"value\"]

File .../python3.12/site-packages/numpyro/primitives.py:53, in apply_stack(msg)
     50     if msg.get(\"stop\"):
     51         break
---> 53 default_process_message(msg)
     55 # A Messenger that sets msg[\"stop\"] == True also prevents application
     56 # of postprocess_message by Messengers above it on the stack
     57 # via the pointer variable from the process_message loop
     58 for handler in _PYRO_STACK[-pointer - 1 :]:

File .../python3.12/site-packages/numpyro/primitives.py:24, in default_process_message(msg)
     22 if msg[\"value\"] is None:
     23     if msg[\"type\"] == \"sample\":
---> 24         msg[\"value\"], msg[\"intermediates\"] = msg[\"fn\"](
     25             *msg[\"args\"], sample_intermediates=True, **msg[\"kwargs\"]
     26         )
     27     else:
     28         msg[\"value\"] = msg[\"fn\"](*msg[\"args\"], **msg[\"kwargs\"])

File .../python3.12/site-packages/numpyro/distributions/distribution.py:369, in Distribution.__call__(self, *args, **kwargs)
    367 sample_intermediates = kwargs.pop(\"sample_intermediates\", False)
    368 if sample_intermediates:
--> 369     return self.sample_with_intermediates(key, *args, **kwargs)
    370 return self.sample(key, *args, **kwargs)

File .../python3.12/site-packages/numpyro/distributions/distribution.py:327, in Distribution.sample_with_intermediates(self, key, sample_shape)
    317 def sample_with_intermediates(self, key, sample_shape=()):
    318     \"\"\"
    319     Same as ``sample`` except that any intermediate computations are
    320     returned (useful for `TransformedDistribution`).
   (...)
    325     :rtype: numpy.ndarray
    326     \"\"\"
--> 327     return self.sample(key, sample_shape=sample_shape), []

File .../python3.12/site-packages/numpyro/distributions/distribution.py:909, in MaskedDistribution.sample(self, key, sample_shape)
    908 def sample(self, key, sample_shape=()):
--> 909     return self.base_dist(rng_key=key, sample_shape=sample_shape)

File .../python3.12/site-packages/numpyro/distributions/distribution.py:370, in Distribution.__call__(self, *args, **kwargs)
    368 if sample_intermediates:
    369     return self.sample_with_intermediates(key, *args, **kwargs)
--> 370 return self.sample(key, *args, **kwargs)

File .../python3.12/site-packages/numpyro/distributions/distribution.py:315, in Distribution.sample(self, key, sample_shape)
    303 def sample(self, key, sample_shape=()):
    304     \"\"\"
    305     Returns a sample from the distribution having shape given by
    306     `sample_shape + batch_shape + event_shape`. Note that when `sample_shape` is non-empty,
   (...)
    313     :rtype: numpy.ndarray
    314     \"\"\"
--> 315     raise NotImplementedError

NotImplementedError: "
}

Note there's no error if you replace reparam_model with model above.

Thanks in advance for input!

Hmm, I think we dont need to reparam the likelihood. We should add an assertion, like other reparam, to disallow that usage.