Potential bug using scan to create random variable in PyMC
junpenglao opened this issue · 6 comments
Currently, a simple random walk example fails for me:
import numpy as np
import pymc as pm
import matplotlib.pyplot as plt
import aesara
import aesara.tensor as at
import aeppl
num_timesteps = 100
# Synthetic observations: cumulative sum of normal draws (mean 0, scale 2.5).
data = np.random.normal(0, 2.5, size=num_timesteps).cumsum()
plt.plot(data);

# Minimal reproduction of the failure reported in this issue: the series is
# built with `aesara.scan` and registered on the model as an observed variable.
# NOTE: `sigma` reaches the inner function through the lambda's closure; it is
# not passed via `non_sequences`, and the inner function returns no RNG updates.
with pm.Model() as m:
    sigma = pm.HalfNormal("sigma", 5.)
    mu = pm.Normal("mu", 0., 1.)
    X_rv, updates = aesara.scan(
        fn=lambda x_tm1: at.random.normal(x_tm1, sigma),
        outputs_info=[{"initial": mu}],
        n_steps=num_timesteps
    )
    m.register_rv(X_rv, name="X_rv", data=data)
    # X_rv = pm.GaussianRandomWalk("X_rv", mu, sigma, observed=data)
    idata = pm.sample()
This fails with the following error:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
/var/folders/7p/srk5qjp563l5f9mrjtp44bh800jqsw/T/ipykernel_33625/1955773877.py in <module>
9 m.register_rv(X_rv, name="X_rv", data=data)
10 # X_rv = pm.GaussianRandomWalk("X_rv", mu, sigma, observed=data)
---> 11 idata = pm.sample()
~/Documents/OSS/pymc/pymc/sampling.py in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
528
529 initial_points = None
--> 530 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
531
532 if isinstance(step, list):
~/Documents/OSS/pymc/pymc/sampling.py in assign_step_methods(model, step, methods, step_kwargs)
204 # variables
205 selected_steps = defaultdict(list)
--> 206 model_logp = model.logp()
207
208 for var in model.value_vars:
~/Documents/OSS/pymc/pymc/model.py in logp(self, vars, jacobian, sum)
756 rv_logps: List[TensorVariable] = []
757 if rv_values:
--> 758 rv_logps = joint_logp(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian)
759 assert isinstance(rv_logps, list)
760
~/Documents/OSS/pymc/pymc/distributions/logprob.py in joint_logp(var, rv_values, jacobian, scaling, transformed, sum, **kwargs)
269 logp_var_dict = {}
270 for value_var in rv_values.values():
--> 271 logp_var_dict[value_var] = temp_logp_var_dict[value_var]
272
273 if scaling:
KeyError: X_rv{[-1.502085..27194e+01]}
Inspecting with aeppl seems to indicate it does not recognize the RV result from scan:
# Pair each random variable with a value variable: the observed data for the
# scan output, and clones of `mu` / `sigma` for the two parameters.
x_vv = at.constant(data)
mu_vv, sigma_vv = mu.clone(), sigma.clone()

# Ask AePPL for the per-variable log-probability terms.
logp_dict = aeppl.factorized_joint_logprob(
    {
        X_rv: x_vv,
        mu: mu_vv,
        sigma: sigma_vv,
    }
)
logp_dict
# ==> {mu: mu_logprob, sigma: sigma_logprob}
cc @ricardoV94
Sounds like the problem is simply not passing `sigma` as a `non_sequence`. AePPL probably requires this information: https://colab.research.google.com/drive/1Yvh0VpZE4Bhtu5-mL4qXnlK0-j45ARBF#scrollTo=WPVbMVisw8lM
Thanks @ricardoV94! Besides passing `sigma` as a `non_sequence`, it is also important to have `{x.owner.inputs[0]: x.owner.outputs[0]}` in the return.
> Thanks @ricardoV94! Besides passing `sigma` as a `non_sequence`, it is also important to have `{x.owner.inputs[0]: x.owner.outputs[0]}` in the return.
Only important for forward sampling, I think. If you use `RandomStream` to create the RVs inside scan, it all happens behind the scenes.
Ah, the gradient (but not the logp) raises an error if you have a scan without the explicit updates... strange.
# Reproduction of the gradient failure: the scan's inner function returns only
# the new variable, without an explicit RNG updates dictionary.
sigma_rv = at.random.halfnormal(0, 5.0, name="sigma")
mu_rv = at.random.normal(0, 1.0, name="mu")


def step(x_tm1, sigma_rv):
    """Draw the next state from a normal centred on the previous state.

    Parameters
    ----------
    x_tm1
        State from the previous scan iteration.
    sigma_rv
        Scale of the normal draw, supplied through ``non_sequences``.

    Returns
    -------
    The newly drawn variable only. The RNG updates dictionary is left
    commented out on purpose; that omission is the case under investigation.
    """
    x = at.random.normal(x_tm1, sigma_rv)
    return x  #, {x.owner.inputs[0]: x.owner.outputs[0]}


scan_rv, updates = aesara.scan(
    fn=step,
    outputs_info=[{"initial": mu_rv}],
    n_steps=num_timesteps,
    non_sequences=[sigma_rv],
)
scan_rv.name = "scan"

# Value variables are clones of the corresponding random variables.
sigma_vv = sigma_rv.clone()
mu_vv = mu_rv.clone()
scan_vv = scan_rv.clone()

logp_dict = aeppl.factorized_joint_logprob({
    sigma_rv: sigma_vv,
    mu_rv: mu_vv,
    scan_rv: scan_vv,
})

# The next line raises
at.grad(logp_dict[scan_vv].sum(), wrt=sigma_vv)
TypeError: Tensor type field must be a TensorType; found <class 'aesara.tensor.random.type.RandomGeneratorType'>.
@junpenglao Does this still fail? If it works by modifying the original example would you mind sharing the modified version?
Yes, this is the working implementation:
# Working version of the example: `sigma` is passed to the scan explicitly via
# `non_sequences`, and the inner function returns the RNG update alongside the
# new variable.
with pm.Model(coords={"timestep": np.arange(num_timesteps)}) as m:
    sigma = pm.HalfNormal("sigma", 5.0)
    mu = pm.Normal("mu", 0.0, 1.0)

    def step(x_tm1, sigma):
        """Return the next random-walk state together with its RNG update.

        Parameters
        ----------
        x_tm1
            State from the previous scan iteration.
        sigma
            Scale of the normal draw, supplied through ``non_sequences``.

        Returns
        -------
        tuple
            The new variable and a dict mapping the first input of the
            variable's owner node to that node's first output.
        """
        x = pm.Normal.dist(x_tm1, sigma)
        # x = at.random.normal(x_tm1, sigma)
        # Return the new variable and the RNG update expression
        return x, {x.owner.inputs[0]: x.owner.outputs[0]}

    X_rv, updates = aesara.scan(
        fn=step, outputs_info=[mu], non_sequences=[sigma], n_steps=num_timesteps
    )
    # Register the scan output on the model with `data` as its observed values.
    m.register_rv(X_rv, name="X_rv", data=data)
    # X_rv = pm.GaussianRandomWalk("X_rv", mu, sigma, observed=data)
    idata = pm.sample()