aesara-devs/aeppl

Potential bug using scan to create random variable in PyMC

junpenglao opened this issue · 6 comments

Currently, a simple random walk example fails for me:

import numpy as np
import pymc as pm
import matplotlib.pyplot as plt

import aesara
import aesara.tensor as at

import aeppl

num_timesteps = 100
data = np.random.normal(0, 2.5, size=num_timesteps).cumsum()
plt.plot(data);

with pm.Model() as m:
    sigma = pm.HalfNormal("sigma", 5.)
    mu = pm.Normal("mu", 0., 1.)
    X_rv, updates = aesara.scan(
        fn=lambda x_tm1: at.random.normal(x_tm1, sigma),
        outputs_info=[{"initial": mu}],
        n_steps=num_timesteps,
    )
    m.register_rv(X_rv, name="X_rv", data=data)
    # X_rv = pm.GaussianRandomWalk("X_rv", mu, sigma, observed=data)
    idata = pm.sample()

with:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/var/folders/7p/srk5qjp563l5f9mrjtp44bh800jqsw/T/ipykernel_33625/1955773877.py in <module>
      9     m.register_rv(X_rv, name="X_rv", data=data)
     10     # X_rv = pm.GaussianRandomWalk("X_rv", mu, sigma, observed=data)
---> 11     idata = pm.sample()

~/Documents/OSS/pymc/pymc/sampling.py in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    528 
    529     initial_points = None
--> 530     step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
    531 
    532     if isinstance(step, list):

~/Documents/OSS/pymc/pymc/sampling.py in assign_step_methods(model, step, methods, step_kwargs)
    204     # variables
    205     selected_steps = defaultdict(list)
--> 206     model_logp = model.logp()
    207 
    208     for var in model.value_vars:

~/Documents/OSS/pymc/pymc/model.py in logp(self, vars, jacobian, sum)
    756         rv_logps: List[TensorVariable] = []
    757         if rv_values:
--> 758             rv_logps = joint_logp(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian)
    759             assert isinstance(rv_logps, list)
    760 

~/Documents/OSS/pymc/pymc/distributions/logprob.py in joint_logp(var, rv_values, jacobian, scaling, transformed, sum, **kwargs)
    269     logp_var_dict = {}
    270     for value_var in rv_values.values():
--> 271         logp_var_dict[value_var] = temp_logp_var_dict[value_var]
    272 
    273     if scaling:

KeyError: X_rv{[-1.502085..27194e+01]}

Inspecting with aeppl seems to indicate that it does not recognize the RV returned by scan:

x_vv = at.constant(data)
mu_vv = mu.clone()
sigma_vv = sigma.clone()

logp_dict = aeppl.factorized_joint_logprob({X_rv: x_vv, mu: mu_vv, sigma: sigma_vv})
logp_dict
# ==> {mu_vv: mu_logprob, sigma_vv: sigma_logprob} -- no entry for X_rv

cc @ricardoV94

Sounds like the problem is simply that sigma is not being passed as a non_sequence. AePPL probably requires this information: https://colab.research.google.com/drive/1Yvh0VpZE4Bhtu5-mL4qXnlK0-j45ARBF#scrollTo=WPVbMVisw8lM

Thanks @ricardoV94! Besides passing sigma as a non_sequence, it is also important to return {x.owner.inputs[0]: x.owner.outputs[0]} (the RNG update) from the step function.
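(For context, a minimal standalone sketch of the Aesara convention being used here: on a RandomVariable node, inputs[0] is the incoming RNG and outputs[0] is its updated state, so that dict is exactly the RNG update scan needs.)

x = at.random.normal(0.0, 1.0)
rng_in = x.owner.inputs[0]      # the RandomGeneratorType input
rng_out = x.owner.outputs[0]    # the updated generator state
assert x is x.owner.outputs[1]  # the draw itself is the second output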

Only important for forward sampling, I think. If you use RandomStream to create the RVs inside scan, it all happens behind the scenes.
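For illustration, a minimal sketch of that approach (standalone, with placeholder constants): RandomStream attaches a default update to its shared RNG, so scan collects the update on its own and the step function only returns the draw.

from aesara.tensor.random.utils import RandomStream

srng = RandomStream(seed=123)

def step(x_tm1, sigma):
    # No update dict needed: the shared RNG carries a default_update
    # that scan picks up automatically
    return srng.normal(x_tm1, sigma)

scan_rv, updates = aesara.scan(
    fn=step,
    outputs_info=[at.as_tensor(0.0)],
    non_sequences=[at.as_tensor(1.0)],
    n_steps=100,
)
# `updates` now already contains the RNG state updates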

Ah, the gradient (but not the logp) raises an error if you have a scan without the explicit updates... strange.

sigma_rv = at.random.halfnormal(0, 5.0, name="sigma")
mu_rv = at.random.normal(0, 1.0, name="mu")

def step(x_tm1, sigma_rv):
    x = at.random.normal(x_tm1, sigma_rv)
    # NB: the RNG update dict is deliberately not returned here
    return x  # , {x.owner.inputs[0]: x.owner.outputs[0]}

scan_rv, updates = aesara.scan(
    fn=step,
    outputs_info=[{"initial": mu_rv}],
    n_steps=num_timesteps,
    non_sequences=[sigma_rv],
)
scan_rv.name = "scan"

sigma_vv = sigma_rv.clone()
mu_vv = mu_rv.clone()
scan_vv = scan_rv.clone()

logp_dict = aeppl.factorized_joint_logprob({
    sigma_rv: sigma_vv,
    mu_rv: mu_vv,
    scan_rv: scan_vv,
})

# The next line raises
at.grad(logp_dict[scan_vv].sum(), wrt=sigma_vv)
TypeError: Tensor type field must be a TensorType; found <class 'aesara.tensor.random.type.RandomGeneratorType'>.
rlouf commented

@junpenglao Does this still fail? If it works after modifying the original example, would you mind sharing the modified version?

Yes, this is the working implementation:

with pm.Model(coords={"timestep": np.arange(num_timesteps)}) as m:
    sigma = pm.HalfNormal("sigma", 5.0)
    mu = pm.Normal("mu", 0.0, 1.0)

    def step(x_tm1, sigma):
        x = pm.Normal.dist(x_tm1, sigma)
        # x = at.random.normal(x_tm1, sigma)
        # Return the new variable and the RNG update expression
        return x, {x.owner.inputs[0]: x.owner.outputs[0]}

    X_rv, updates = aesara.scan(
        fn=step, outputs_info=[mu], non_sequences=[sigma], n_steps=num_timesteps
    )
    m.register_rv(X_rv, name="X_rv", data=data)
    # X_rv = pm.GaussianRandomWalk("X_rv", mu, sigma, observed=data)
    idata = pm.sample()
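
As a cross-check outside of PyMC (a sketch reusing sigma_rv, mu_rv and their value variables from the earlier snippet), the same aeppl gradient check should now compile once step returns the RNG update:

def step(x_tm1, sigma_rv):
    x = at.random.normal(x_tm1, sigma_rv)
    # Return the RNG update explicitly this time
    return x, {x.owner.inputs[0]: x.owner.outputs[0]}

scan_rv, updates = aesara.scan(
    fn=step,
    outputs_info=[{"initial": mu_rv}],
    n_steps=num_timesteps,
    non_sequences=[sigma_rv],
)
scan_vv = scan_rv.clone()

logp_dict = aeppl.factorized_joint_logprob({
    sigma_rv: sigma_vv,
    mu_rv: mu_vv,
    scan_rv: scan_vv,
})

# Per the discussion above, this should no longer raise the
# RandomGeneratorType error
at.grad(logp_dict[scan_vv].sum(), wrt=sigma_vv)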