Scan RVs cannot be composed with other measurable operations
ricardoV94 opened this issue
The `Scan` rewrite seems to require a value variable associated directly with the `Scan` output. As a result, the following simple case fails:
```python
import aesara
import aesara.tensor as at
from aeppl import factorized_joint_logprob

# A Scan that draws 10 standard normal values
x, x_updates = aesara.scan(
    fn=lambda: at.random.normal(0, 1),
    n_steps=10,
)
x.name = "x"
x_vv = x.clone()
assert factorized_joint_logprob({x: x_vv})  # Fine

# A deterministic shift of the Scan output
sx = x + 5
sx.name = "sx"
sx_vv = sx.clone()
assert factorized_joint_logprob({sx: sx_vv})  # AssertionError
```
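For reference, `sx = x + 5` is only a constant shift of ten independent standard normal draws, so its logprob has a simple closed form: a shift has a log-Jacobian of zero, giving `log p_sx(v) = log p_x(v - 5)`. The plain-Python sketch below is not aeppl code; `normal_logpdf` and `shifted_scan_logprob` are hypothetical helpers that compute, summed over the ten steps, the elementwise logprob one would expect the failing call to produce.

```python
import math


def normal_logpdf(v, mu=0.0, sigma=1.0):
    """Log-density of Normal(mu, sigma) evaluated at v."""
    z = (v - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def shifted_scan_logprob(values, shift=5.0):
    """Summed logprob of sx = x + shift, where x holds iid N(0, 1) draws.

    A constant shift has log-Jacobian 0, so log p_sx(v) = log p_x(v - shift).
    """
    return sum(normal_logpdf(v - shift) for v in values)


# Evaluate at the mode of sx: every step equals the shift
values = [5.0] * 10
print(shifted_scan_logprob(values))
```

Writing the result as a change of variables on `x` suggests that the rewrite only needs to see the shift applied on top of the `Scan` output, rather than requiring a value variable on the `Scan` itself.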
This limitation came up in pymc-devs/pymc#6119, where the goal would be to derive the logprob of a concatenated graph.
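When the concatenated pieces are independent, the logprob of the concatenation is the sum of each piece's logprob evaluated on its own slice of the value. The sketch below assumes that independence and uses a made-up layout (one N(0, 1) value followed by the ten shifted `Scan` steps); the layout is not taken from pymc-devs/pymc#6119, and `concat_logprob` is a hypothetical helper, not an aeppl API.

```python
import math


def normal_logpdf(v, mu=0.0, sigma=1.0):
    """Log-density of Normal(mu, sigma) evaluated at v."""
    z = (v - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def concat_logprob(value, pieces):
    """Logprob of a 1-D concatenation of independent pieces.

    pieces: list of (length, elementwise_logpdf) in concatenation order.
    Each piece's logpdf is applied only to its own slice of `value`.
    """
    total, start = 0.0, 0
    for length, logpdf in pieces:
        segment = value[start:start + length]
        if len(segment) != length:
            raise ValueError("value is shorter than the concatenated pieces")
        total += sum(logpdf(v) for v in segment)
        start += length
    if start != len(value):
        raise ValueError("value is longer than the concatenated pieces")
    return total


# Hypothetical layout: one N(0, 1) value, then 10 steps shifted by 5
pieces = [(1, normal_logpdf), (10, lambda v: normal_logpdf(v - 5.0))]
value = [0.0] + [5.0] * 10
print(concat_logprob(value, pieces))
```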