Allow the use of `stack`ed composite terms
brandonwillard opened this issue · 4 comments
We should make it possible to use random/stochastic composite terms constructed using functions like `aesara.tensor.stack` (e.g. the `Join` `Op`).
For example, the following should work:
```python
import aesara
import aesara.tensor as at

from aeppl.joint_logprob import factorized_joint_logprob

M = 2

srng = at.random.RandomStream(seed=2023532)

# Two independent Dirichlet-distributed rows
p_0_rv = srng.dirichlet(at.ones(M), name="p_0")
p_1_rv = srng.dirichlet(at.ones(M), name="p_1")

# Stack them into a single (2, M) random matrix; this produces a `Join` node
Gamma_rv = at.stack([p_0_rv, p_1_rv])

aesara.dprint(Gamma_rv)
# Join [id A] ''
#  |TensorConstant{0} [id B]
#  |InplaceDimShuffle{x,0} [id C] ''
#  | |dirichlet_rv{1, (1,), floatX, False}.1 [id D] 'p_0'
#  |   |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FFA4B987050>) [id E]
#  |   |TensorConstant{[]} [id F]
#  |   |TensorConstant{11} [id G]
#  |   |Alloc [id H] ''
#  |     |TensorConstant{1.0} [id I]
#  |     |TensorConstant{2} [id J]
#  |InplaceDimShuffle{x,0} [id K] ''
#    |dirichlet_rv{1, (1,), floatX, False}.1 [id L] 'p_1'
#      |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FFA4B987450>) [id M]
#      |TensorConstant{[]} [id N]
#      |TensorConstant{11} [id O]
#      |Alloc [id P] ''
#        |TensorConstant{1.0} [id Q]
#        |TensorConstant{2} [id R]

# An un-owned copy of the stacked term, used as its value variable
gamma_vv = Gamma_rv.clone()
gamma_vv.name = "gamma"

# Currently does not work
factorized_joint_logprob({Gamma_rv: gamma_vv})
```
Currently, we are required to work in terms of `p_0_rv` and `p_1_rv`, which can be quite inconvenient.
This is closely related to our existing mixture support, so it shouldn't be particularly difficult to add.
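For reference, this is roughly what the stacked term's logp should reduce to when the stacked components are independent: row `k` of the value is scored against the `k`-th component alone, so the terms match what we currently get by working with `p_0_rv` and `p_1_rv` directly. The sketch below is plain Python with a hand-written Dirichlet log-density; `dirichlet_logpdf` and `stacked_dirichlet_logp` are hypothetical helpers for illustration, not aeppl functions.

```python
import math
from typing import List, Sequence


def dirichlet_logpdf(x: Sequence[float], alpha: Sequence[float]) -> float:
    """Log-density of Dirichlet(alpha) at the point x."""
    if len(x) != len(alpha):
        raise ValueError("x and alpha must have the same length")
    # Outside the open simplex the density is zero.
    if any(v <= 0.0 for v in x) or not math.isclose(sum(x), 1.0):
        return -math.inf
    log_norm = math.lgamma(sum(alpha)) - sum(math.lgamma(a) for a in alpha)
    return log_norm + sum((a - 1.0) * math.log(v) for a, v in zip(alpha, x))


def stacked_dirichlet_logp(
    gamma: Sequence[Sequence[float]], alphas: Sequence[Sequence[float]]
) -> List[float]:
    """One logp term per row of a value for stack([p_0, ..., p_K]).

    Row k of the stacked value is scored against the k-th component only,
    which assumes the stacked components are independent.
    """
    if len(gamma) != len(alphas):
        raise ValueError("need exactly one alpha vector per stacked row")
    return [dirichlet_logpdf(row, alpha) for row, alpha in zip(gamma, alphas)]


# Mirrors the example above: M = 2 and both rows are Dirichlet(ones(2)).
gamma_value = [[0.3, 0.7], [0.9, 0.1]]
terms = stacked_dirichlet_logp(gamma_value, [[1.0, 1.0], [1.0, 1.0]])
print(terms)
```

Both terms are zero here because Dirichlet(ones(2)) is uniform on the simplex; a non-flat `alpha` such as `[2.0, 1.0]` gives `log(2 * x_0)` instead.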
I explored this a bit. The biggest challenge I found was keeping the size information when replacing the RVs with value variables (for instance, when stacking `[random.normal(0, size=2), random.normal(1, size=3)]`), so that we can later dispatch the right subtensor to each logp.
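A small illustration of why the sizes matter: pieces of different lengths have to be concatenated along an existing axis (a `Join`; `at.stack` requires equal shapes), and the joined value alone does not say where one piece ends. In this plain-Python sketch, which assumes unit-variance normals and a 1-D value, the same length-5 value gets a different total logp under a 2 + 3 split than under a 3 + 2 split. The helper names are hypothetical.

```python
import math
from typing import Sequence


def normal_logpdf(x: float, mu: float, sigma: float = 1.0) -> float:
    """Log-density of Normal(mu, sigma) at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def joined_logp(value: Sequence[float], sizes: Sequence[int], mus: Sequence[float]) -> float:
    """Total logp of a 1-D value made of consecutive blocks of the given sizes.

    Block k (of length sizes[k]) is scored under Normal(mus[k], 1).
    """
    if len(sizes) != len(mus) or sum(sizes) != len(value):
        raise ValueError("sizes/mus mismatch, or sizes do not add up to len(value)")
    total, start = 0.0, 0
    for size, mu in zip(sizes, mus):
        total += sum(normal_logpdf(x, mu) for x in value[start:start + size])
        start += size
    return total


# A length-5 value for the joined term.
value = [0.1, -0.4, 1.2, 0.8, 1.0]
a = joined_logp(value, [2, 3], [0.0, 1.0])  # normal(0, size=2) then normal(1, size=3)
b = joined_logp(value, [3, 2], [0.0, 1.0])  # normal(0, size=3) then normal(1, size=2)
print(a - b)  # about 0.7: only the element 1.2 changes blocks
```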
It should be much more straightforward after we implement the `ValuedVariable`s.
Actually, what I tried back then was to chop up the value variable during the rewrite phase and assign each part to the respective `RandomVariable`s, but this wouldn't work now that we return a dictionary of value-logp pairs.
I think we would just need to create a "MeasurableStack" that keeps the RVs as inputs, so retrieving the size information wouldn't be a problem.
If you mean something like `aeppl.mixture.MixtureRV`, but for `Join`ed terms, then, yes, that should work just fine.
These changes will require refactoring our approach to handling mixtures, though: with such an intermediate representation, our mixtures would necessarily be functions of those forms (i.e. mixtures would become composite stacked measurable variables indexed by measurable variables).
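To make that reading concrete: a mixture can be viewed as `stack([Y_0, ..., Y_K])[I]` with `I ~ Categorical(w)`, and its logp marginalizes the index out, `log(sum_k w_k * p_k(x))`. The plain-Python sketch below assumes scalar values and independent components; `indexed_stack_logp` is a hypothetical helper, not how `aeppl.mixture` is actually implemented.

```python
import math
from typing import Callable, Sequence


def logsumexp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(xs)))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def indexed_stack_logp(
    x: float,
    weights: Sequence[float],
    component_logps: Sequence[Callable[[float], float]],
) -> float:
    """logp at x of stack(components)[I], with I ~ Categorical(weights) summed out."""
    if len(weights) != len(component_logps) or any(w < 0.0 for w in weights):
        raise ValueError("need one non-negative weight per component")
    if not math.isclose(sum(weights), 1.0):
        raise ValueError("weights must sum to 1")
    # Zero-weight components contribute nothing (and log(0) would raise).
    return logsumexp(
        [math.log(w) + lp(x) for w, lp in zip(weights, component_logps) if w > 0.0]
    )


def normal_logp(mu: float) -> Callable[[float], float]:
    """Elementwise Normal(mu, 1) log-density."""
    return lambda x: -0.5 * (x - mu) ** 2 - 0.5 * math.log(2.0 * math.pi)


mix = indexed_stack_logp(0.5, [0.3, 0.7], [normal_logp(0.0), normal_logp(1.0)])
print(mix)
```

With all the weight on one component the result is just that component's logp, which is a quick sanity check on the marginalization.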
I was thinking of it in isolation. For that, we still need to sort out the nested rewrites strategy.