aesara-devs/aeppl

Scan RVs cannot be composed with other measurable operations

ricardoV94 opened this issue · 0 comments

The Scan rewrite seems to require that a value variable be directly associated with the Scan output. As a result, the following simple case fails:

import aesara, aesara.tensor as at
from aeppl import factorized_joint_logprob

x, x_updates = aesara.scan(
    fn=lambda: at.random.normal(0, 1),
    n_steps=10,
)
x.name = "x"
x_vv = x.clone()
assert factorized_joint_logprob({x: x_vv})  # Fine

sx = x + 5
sx.name = "sx"
sx_vv = sx.clone()
assert factorized_joint_logprob({sx: sx_vv})  # AssertionError
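
For reference, here is a minimal pure-Python sketch of what the failing call would presumably be expected to compute. It assumes the Scan output is 10 i.i.d. Normal(0, 1) draws and that adding a constant has a unit Jacobian, so the log-density of `sx` at a value `v` equals the log-density of `x` at `v - 5`. The helpers `normal_logpdf` and `shifted_scan_logp` are hypothetical names used only for illustration; they are not part of aeppl or aesara.

```python
import math


def normal_logpdf(value, mu=0.0, sigma=1.0):
    """Log-density of Normal(mu, sigma) evaluated at `value`."""
    z = (value - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def shifted_scan_logp(sx_values, shift=5.0):
    """Hypothetical reference for the elementwise logprob of sx = x + shift.

    Assumes each x_t ~ Normal(0, 1) independently. A constant shift has
    unit Jacobian, so logp_sx(v) = logp_x(v - shift).
    """
    return [normal_logpdf(v - shift) for v in sx_values]


# One value per Scan step (n_steps=10 in the example above).
sx_values = [5.0] * 10
logps = shifted_scan_logp(sx_values)
```

A fix for this issue would be expected to return a graph that evaluates to the same numbers when `sx_vv` is set to `sx_values`.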

This was brought up in pymc-devs/pymc#6119, where a concatenated graph would be the goal.