Double `Unmeasurable` prefix in components of mixture sub-graphs
larryshamalama opened this issue
Mixture sub-graphs are identified by `mixture_replace`, and each component is made unmeasurable (see here). However, when the stack is defined, the tensors being combined go through a similar procedure (see here), which is why the type `UnmeasurableUnmeasurableNormalRV` appears. This becomes a problem when using functions such as `aesara.graph.basic.equal_computations` to check whether `Op` types are identical.
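The mechanism can be illustrated with a toy model in plain Python. The class `NormalRV` and the function `make_unmeasurable` below are hypothetical stand-ins, not aeppl's actual code. The sketch only assumes that each pass creates a new subclass whose name is the prefix followed by the original class name.

```python
# Toy model (hypothetical, not aeppl's code): each "make unmeasurable" pass
# subclasses the given Op type and prepends a prefix to its name.
class NormalRV:
    """Stand-in for aesara.tensor.random.basic.NormalRV."""


def make_unmeasurable(op_type: type, prefix: str = "Unmeasurable") -> type:
    # Unconditionally create a new subclass with a prefixed name.
    return type(f"{prefix}{op_type.__name__}", (op_type,), {})


once = make_unmeasurable(NormalRV)  # first pass over a mixture component
twice = make_unmeasurable(once)     # second pass over the same component

print(once.__name__)   # UnmeasurableNormalRV
print(twice.__name__)  # UnmeasurableUnmeasurableNormalRV
print(once is twice)   # False: the two passes produce different classes
```

Because every pass creates a fresh class, anything that compares `Op` types treats the two results as different, even though both ultimately wrap `NormalRV`.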
Below is an example of this error:
```python
import aesara.tensor as at
from aeppl.opt import construct_ir_fgraph

srng = at.random.RandomStream(seed=2320)

I_rv = srng.bernoulli(0.5, name="I")
X_rv = srng.normal(-10., 0.1, name="X")
Y_rv = srng.normal(10., 0.1, name="Y")

# mixture from stack
Z_rv = at.stack([X_rv, Y_rv])[I_rv]
Z_rv.name = "Z"

z_vv = Z_rv.clone()
i_vv = I_rv.clone()

fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv, I_rv: i_vv})

fgraph.outputs[0]  # Z-mixture
fgraph.outputs[0].owner.inputs  # [NoneConst, ScalarFromTensor.0, X, Y]
type(fgraph.outputs[0].owner.inputs[2].owner.op)  # aesara.tensor.random.basic.UnmeasurableUnmeasurableNormalRV
```
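To spot this symptom in a graph like the one above, a small helper can count how many times the prefix is repeated at the start of a type name. `count_prefix` is a hypothetical helper, not part of aeppl or Aesara. In the example above it could be applied to `type(fgraph.outputs[0].owner.inputs[2].owner.op).__name__`.

```python
def count_prefix(type_name: str, prefix: str = "Unmeasurable") -> int:
    """Count how many times `prefix` is repeated at the start of `type_name`."""
    if not prefix:
        # An empty prefix would match forever, so reject it.
        raise ValueError("prefix must be non-empty")
    count = 0
    while type_name.startswith(prefix):
        type_name = type_name[len(prefix):]
        count += 1
    return count


for name in ["NormalRV", "UnmeasurableNormalRV", "UnmeasurableUnmeasurableNormalRV"]:
    print(name, count_prefix(name))
# NormalRV 0
# UnmeasurableNormalRV 1
# UnmeasurableUnmeasurableNormalRV 2
```

A count above 1 indicates that the prefix was assigned more than once.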
I'm not sure what the preferred approach to addressing this is. Is the problem outlined above specific to how mixture components are made `Unmeasurable`, or is it a general issue with the `assign_custom_measurable_outputs` function? Below are some ideas:
- Check in `mixture_replace` whether the components are already deemed `Unmeasurable`;
- Check whether `op_type.__name__.startswith(type_prefix)` and, if so, return the original `node` (or even `new_node`, but without the prefix inserted?).
It might also be possible to return the original `node` in the call to `assign_custom_measurable_outputs` when the `Op` is already assigned the given `measurable_outputs_fn` in the `_get_measurable_outputs` dispatch.
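The dispatch-based alternative can be sketched with `functools.singledispatch`, used here as a stand-in for the `_get_measurable_outputs` dispatch. `get_measurable_outputs`, `no_measurable_outputs`, and `assign_measurable_outputs_type` are hypothetical names, and the toy `node` is just a list of outputs. The check asks the dispatcher which implementation is registered for the `Op` type and reuses that type if the implementation is already the requested function.

```python
from functools import singledispatch


@singledispatch
def get_measurable_outputs(op, node):
    """Hypothetical stand-in for the `_get_measurable_outputs` dispatcher."""
    return list(node)  # default: every output of the toy node is measurable


def no_measurable_outputs(op, node):
    # Function registered for "unmeasurable" Op types.
    return []


class NormalRV:
    """Stand-in for an Aesara RandomVariable Op type (hypothetical)."""


def assign_measurable_outputs_type(op_type, measurable_outputs_fn, prefix="Unmeasurable"):
    # If the dispatcher already resolves op_type to this exact function,
    # reuse the existing type instead of creating a prefixed subclass.
    if get_measurable_outputs.dispatch(op_type) is measurable_outputs_fn:
        return op_type
    new_type = type(f"{prefix}{op_type.__name__}", (op_type,), {})
    get_measurable_outputs.register(new_type, measurable_outputs_fn)
    return new_type


once = assign_measurable_outputs_type(NormalRV, no_measurable_outputs)
twice = assign_measurable_outputs_type(once, no_measurable_outputs)
print(once.__name__, twice is once)  # UnmeasurableNormalRV True
```

Unlike the name check, this does not depend on how the class is named, but it only recognizes the exact function object that was registered for the type.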