Error in model using CircularReparam when trying to use Predictive
Closed this issue · 1 comment
tomwallis commented
Hello,
I'm trying to use the `Predictive` class with a model whose response variable is circular. I get a deep `NotImplementedError` when I try to use `Predictive` with the (recommended) circular reparameterization, but not when I don't use it.
Is this error related to not being able to use reparameterization on the observed variable (likelihood)? Or something specific to circular reparameterization? Any help much appreciated!
Minimal example:
import numpyro
import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.handlers import reparam
from numpyro.infer.reparam import CircularReparam
# Predictor: evenly spaced values from -pi to pi (no `num` is given, so the
# library's default number of points is used).
x = jnp.linspace(-jnp.pi, jnp.pi)
# Response: x plus noise from random.normal, drawn with a fixed key so the
# example is reproducible. Nothing wraps y back onto the circle, so some values
# presumably fall outside [-pi, pi] (the "out-of-bounds" the author mentions).
y = x + random.normal(random.key(234), shape=x.shape) # acknowledged that this will be out-of-bounds
def model(x, y=None):
    """Von Mises regression of a circular response ``y`` on ``x``.

    Sample sites:
      * ``"b"``     -- slope, prior ``Normal(0, 1)``.
      * ``"sigma"`` -- scale-like parameter, prior ``Exponential(1)``.
      * ``"obs"``   -- likelihood, ``VonMises(loc=b * x, concentration=kappa)``
        with ``kappa = 1 / sigma**2``; ``y`` is passed as ``obs``.

    :param x: predictor values.
    :param y: observed responses. Defaults to ``None``, which is how the
        ``Predictive`` call in this example invokes the model (only ``x`` is
        passed).
    """
    b = numpyro.sample("b", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # Convert sigma into a von Mises concentration: larger sigma gives a
    # smaller kappa.
    kappa = 1 / sigma**2
    numpyro.sample(
        "obs",
        dist.VonMises(loc=b * x, concentration=kappa),
        obs=y,
    )
# Wrap the model so that the "obs" site is handled by CircularReparam.
# NOTE: "obs" is the observed (likelihood) site of `model`, not a latent one.
reparam_model = reparam(model, config={"obs": CircularReparam()})
# Start from this source of randomness. We will split keys for subsequent operations.
rng_key = random.key(4159)
rng_key, rng_key_ = random.split(rng_key)
# Run NUTS on the reparameterized model, conditioning on both x and y.
kernel = NUTS(reparam_model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
mcmc.run(rng_key_, x=x, y=y)
mcmc.print_summary()
samples_1 = mcmc.get_samples()
# Generate posterior predictions on fitted values.
rng_key, rng_key_ = random.split(rng_key)
predictive = Predictive(reparam_model, samples_1)
# Only x is passed here, so the model's y stays at its default of None.
# This is the call that raises NotImplementedError in the error trace.
predictions = predictive(rng_key_, x=x)["obs"]
Error trace:
{
"name": "NotImplementedError",
"message": "",
"stack": "---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In[78], line 4
2 rng_key, rng_key_ = random.split(rng_key)
3 predictive = Predictive(reparam_model, samples_1)
----> 4 predictions = predictive(rng_key_, x=x)[\"obs\"]
5 df[\"Mean Predictions\"] = jnp.mean(predictions, axis=0)
6 df.head()
File .../python3.12/site-packages/numpyro/infer/util.py:1037, in Predictive.__call__(self, rng_key, *args, **kwargs)
1027 \"\"\"
1028 Returns dict of samples from the predictive distribution. By default, only sample sites not
1029 contained in `posterior_samples` are returned. This can be modified by changing the
(...)
1034 :param kwargs: model kwargs.
1035 \"\"\"
1036 if self.batch_ndims == 0 or self.params == {} or self.guide is None:
-> 1037 return self._call_with_params(rng_key, self.params, args, kwargs)
1038 elif self.batch_ndims == 1: # batch over parameters
1039 batch_size = jnp.shape(jax.tree.flatten(self.params)[0][0])[0]
File .../python3.12/site-packages/numpyro/infer/util.py:1013, in Predictive._call_with_params(self, rng_key, params, args, kwargs)
1001 posterior_samples = _predictive(
1002 guide_rng_key,
1003 guide,
(...)
1010 exclude_deterministic=self.exclude_deterministic,
1011 )
1012 model = substitute(self.model, self.params)
-> 1013 return _predictive(
1014 rng_key,
1015 model,
1016 posterior_samples,
1017 self._batch_shape,
1018 return_sites=self.return_sites,
1019 infer_discrete=self.infer_discrete,
1020 parallel=self.parallel,
1021 model_args=args,
1022 model_kwargs=kwargs,
1023 exclude_deterministic=self.exclude_deterministic,
1024 )
File .../python3.12/site-packages/numpyro/infer/util.py:846, in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, exclude_deterministic, model_args, model_kwargs)
844 rng_key = rng_key.reshape(batch_shape + key_shape)
845 chunk_size = num_samples if parallel else 1
--> 846 return soft_vmap(
847 single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size
848 )
File .../python3.12/site-packages/numpyro/util.py:452, in soft_vmap(fn, xs, batch_ndims, chunk_size)
446 xs = jax.tree.map(
447 lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]),
448 xs,
449 )
450 fn = vmap(fn)
--> 452 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
453 map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
454 ys = jax.tree.map(
455 lambda y: jnp.reshape(
456 y, (int(np.prod(jnp.shape(y)[:map_ndims])),) + jnp.shape(y)[map_ndims:]
457 )[:batch_size],
458 ys,
459 )
[... skipping hidden 12 frame]
File .../python3.12/site-packages/numpyro/infer/util.py:819, in _predictive.<locals>.single_prediction(val)
810 return (
811 samples.get(msg[\"name\"]) if msg[\"type\"] != \"deterministic\" else None
812 )
814 substituted_model = (
815 substitute(masked_model, substitute_fn=_samples_wo_deterministic)
816 if exclude_deterministic
817 else substitute(masked_model, samples)
818 )
--> 819 model_trace = trace(seed(substituted_model, rng_key)).get_trace(
820 *model_args, **model_kwargs
821 )
822 pred_samples = {name: site[\"value\"] for name, site in model_trace.items()}
824 if return_sites is not None:
File .../python3.12/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 \"\"\"
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: `OrderedDict` containing the execution trace.
170 \"\"\"
--> 171 self(*args, **kwargs)
172 return self.trace
File .../python3.12/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File .../python3.12/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
[... skipping similar frames: Messenger.__call__ at line 105 (3 times)]
File .../python3.12/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
Cell In[72], line 5, in model(x, y)
3 sigma = numpyro.sample(\"sigma\", dist.Exponential(1))
4 kappa = 1 / sigma**2
----> 5 numpyro.sample(
6 \"obs\",
7 dist.VonMises(loc=b * x, concentration=kappa),
8 obs=y,
9 )
File .../python3.12/site-packages/numpyro/primitives.py:222, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
207 initial_msg = {
208 \"type\": \"sample\",
209 \"name\": name,
(...)
218 \"infer\": {} if infer is None else infer,
219 }
221 # ...and use apply_stack to send it to the Messengers
--> 222 msg = apply_stack(initial_msg)
223 return msg[\"value\"]
File .../python3.12/site-packages/numpyro/primitives.py:47, in apply_stack(msg)
45 pointer = 0
46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47 handler.process_message(msg)
48 # When a Messenger sets the \"stop\" field of a message,
49 # it prevents any Messengers above it on the stack from being applied.
50 if msg.get(\"stop\"):
File .../python3.12/site-packages/numpyro/handlers.py:583, in reparam.process_message(self, msg)
580 if reparam is None:
581 return
--> 583 new_fn, value = reparam(msg[\"name\"], msg[\"fn\"], msg[\"value\"])
585 if value is not None:
586 if new_fn is None:
File .../python3.12/site-packages/numpyro/infer/reparam.py:344, in CircularReparam.__call__(self, name, fn, obs)
342 # Draw parameter-free noise.
343 new_fn = dist.ImproperUniform(constraints.real, fn.batch_shape, fn.event_shape)
--> 344 value = numpyro.sample(
345 f\"{name}_unwrapped\",
346 new_fn,
347 obs=obs,
348 )
350 # Differentiably transform.
351 value = jnp.remainder(value + math.pi, 2 * math.pi) - math.pi
File .../python3.12/site-packages/numpyro/primitives.py:222, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
207 initial_msg = {
208 \"type\": \"sample\",
209 \"name\": name,
(...)
218 \"infer\": {} if infer is None else infer,
219 }
221 # ...and use apply_stack to send it to the Messengers
--> 222 msg = apply_stack(initial_msg)
223 return msg[\"value\"]
File .../python3.12/site-packages/numpyro/primitives.py:53, in apply_stack(msg)
50 if msg.get(\"stop\"):
51 break
---> 53 default_process_message(msg)
55 # A Messenger that sets msg[\"stop\"] == True also prevents application
56 # of postprocess_message by Messengers above it on the stack
57 # via the pointer variable from the process_message loop
58 for handler in _PYRO_STACK[-pointer - 1 :]:
File .../python3.12/site-packages/numpyro/primitives.py:24, in default_process_message(msg)
22 if msg[\"value\"] is None:
23 if msg[\"type\"] == \"sample\":
---> 24 msg[\"value\"], msg[\"intermediates\"] = msg[\"fn\"](
25 *msg[\"args\"], sample_intermediates=True, **msg[\"kwargs\"]
26 )
27 else:
28 msg[\"value\"] = msg[\"fn\"](*msg[\"args\"], **msg[\"kwargs\"])
File .../python3.12/site-packages/numpyro/distributions/distribution.py:369, in Distribution.__call__(self, *args, **kwargs)
367 sample_intermediates = kwargs.pop(\"sample_intermediates\", False)
368 if sample_intermediates:
--> 369 return self.sample_with_intermediates(key, *args, **kwargs)
370 return self.sample(key, *args, **kwargs)
File .../python3.12/site-packages/numpyro/distributions/distribution.py:327, in Distribution.sample_with_intermediates(self, key, sample_shape)
317 def sample_with_intermediates(self, key, sample_shape=()):
318 \"\"\"
319 Same as ``sample`` except that any intermediate computations are
320 returned (useful for `TransformedDistribution`).
(...)
325 :rtype: numpy.ndarray
326 \"\"\"
--> 327 return self.sample(key, sample_shape=sample_shape), []
File .../python3.12/site-packages/numpyro/distributions/distribution.py:909, in MaskedDistribution.sample(self, key, sample_shape)
908 def sample(self, key, sample_shape=()):
--> 909 return self.base_dist(rng_key=key, sample_shape=sample_shape)
File .../python3.12/site-packages/numpyro/distributions/distribution.py:370, in Distribution.__call__(self, *args, **kwargs)
368 if sample_intermediates:
369 return self.sample_with_intermediates(key, *args, **kwargs)
--> 370 return self.sample(key, *args, **kwargs)
File .../python3.12/site-packages/numpyro/distributions/distribution.py:315, in Distribution.sample(self, key, sample_shape)
303 def sample(self, key, sample_shape=()):
304 \"\"\"
305 Returns a sample from the distribution having shape given by
306 `sample_shape + batch_shape + event_shape`. Note that when `sample_shape` is non-empty,
(...)
313 :rtype: numpy.ndarray
314 \"\"\"
--> 315 raise NotImplementedError
NotImplementedError: "
}
Note that there's no error if you replace `reparam_model` with `model` above.
Thanks in advance for input!
fehiepsi commented
Hmm, I think we don't need to reparam the likelihood. We should add an assertion, like the other reparams, to disallow that usage.