aesara-devs/aeppl

Double `Unmeasurable` prefix in components of mixture sub-graphs

larryshamalama opened this issue · 2 comments

Mixture sub-graphs are identified by mixture_replace, and each of their components is made unmeasurable there (see here). However, when the stack is defined, the tensors being combined go through a similar procedure (here), which is why the type UnmeasurableUnmeasurableNormalRV appears. This becomes a problem when using functions such as aesara.graph.basic.equal_computations to check whether Op types are identical.

The following example reproduces the problem:

```python
import aesara.tensor as at
from aeppl.opt import construct_ir_fgraph

srng = at.random.RandomStream(seed=2320)

I_rv = srng.bernoulli(0.5, name="I")
X_rv = srng.normal(-10., 0.1, name="X")
Y_rv = srng.normal(10., 0.1, name="Y")

# mixture from stack
Z_rv = at.stack([X_rv, Y_rv])[I_rv]
Z_rv.name = "Z"
z_vv = Z_rv.clone()

i_vv = I_rv.clone()

fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv, I_rv: i_vv})

fgraph.outputs[0] # Z-mixture
fgraph.outputs[0].owner.inputs # [NoneConst, ScalarFromTensor.0, X, Y]
type(fgraph.outputs[0].owner.inputs[2].owner.op) # aesara.tensor.random.basic.UnmeasurableUnmeasurableNormalRV
```
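To make the mechanism concrete, here is a minimal pure-Python sketch that uses neither Aesara nor aeppl. It assumes that making an Op unmeasurable amounts to creating a subclass whose name is the original type name with a prefix prepended; NormalRV and make_unmeasurable are hypothetical stand-ins, not aeppl's actual code. Applying the helper twice to the same component yields the doubled name seen above, and the two resulting types are distinct.

```python
class NormalRV:
    """Hypothetical stand-in for a random-variable Op type."""


def make_unmeasurable(op_type, type_prefix="Unmeasurable"):
    """Hypothetical helper: subclass `op_type` under a prefixed name."""
    return type(f"{type_prefix}{op_type.__name__}", (op_type,), {})


# One pass marks the component as unmeasurable...
once = make_unmeasurable(NormalRV)
# ...and a second pass over the same component prefixes it again.
twice = make_unmeasurable(once)

print(once.__name__)   # UnmeasurableNormalRV
print(twice.__name__)  # UnmeasurableUnmeasurableNormalRV
# The two types differ, so any comparison based on the Op type fails.
print(once is twice)   # False
```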

I'm not sure what the preferred approach to address this is. Is the problem outlined above inherent to marking nodes as Unmeasurable, or is it general to the assign_custom_measurable_outputs function? Here are some ideas:

  • Check in mixture_replace whether components are already Unmeasurable;
  • In assign_custom_measurable_outputs, check whether op_type.__name__.startswith(type_prefix) and, if so, return the original node (or perhaps new_node, but without inserting the prefix again?).
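The name-prefix check from the ideas above could look roughly like the following sketch. It again uses hypothetical stand-ins (NormalRV, make_unmeasurable) rather than aeppl's real assign_custom_measurable_outputs, and assumes the prefix is applied by subclassing under a new name: the helper returns the type unchanged when its name already carries the prefix.

```python
class NormalRV:
    """Hypothetical stand-in for a random-variable Op type."""


def make_unmeasurable(op_type, type_prefix="Unmeasurable"):
    """Hypothetical helper: prefix the type name unless it is already prefixed."""
    if op_type.__name__.startswith(type_prefix):
        # Already unmeasurable: reuse the existing type instead of wrapping it again.
        return op_type
    return type(f"{type_prefix}{op_type.__name__}", (op_type,), {})


once = make_unmeasurable(NormalRV)
twice = make_unmeasurable(once)

print(twice.__name__)  # UnmeasurableNormalRV
print(once is twice)   # True
```

A name check is cheap, but it relies on naming alone: an Op whose own name happened to start with the prefix would also be skipped.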

It might also be possible for assign_custom_measurable_outputs to return the original node when the Op is already registered with the given measurable_outputs_fn in the _get_measurable_outputs dispatch.
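The dispatch-based alternative could be sketched with functools.singledispatch standing in for the _get_measurable_outputs dispatch. All names here (get_measurable_outputs, no_measurable_outputs, assign_outputs, NormalRV) are hypothetical, and the sketch assumes the dispatcher is keyed on the Op type: if the type already dispatches to the requested measurable_outputs_fn, it is returned as-is instead of being wrapped again.

```python
from functools import singledispatch


class NormalRV:
    """Hypothetical stand-in for a random-variable Op type."""


@singledispatch
def get_measurable_outputs(op, node=None):
    """Hypothetical stand-in for the dispatcher's default case."""
    return "all outputs"


def no_measurable_outputs(op, node=None):
    """Hypothetical measurable_outputs_fn that marks nothing as measurable."""
    return None


def assign_outputs(op_type, measurable_outputs_fn, type_prefix="Unmeasurable"):
    """Hypothetical helper: skip re-wrapping if the function is already dispatched to."""
    if get_measurable_outputs.dispatch(op_type) is measurable_outputs_fn:
        # The Op type already uses this function; return it unchanged.
        return op_type
    new_type = type(f"{type_prefix}{op_type.__name__}", (op_type,), {})
    get_measurable_outputs.register(new_type, measurable_outputs_fn)
    return new_type


once = assign_outputs(NormalRV, no_measurable_outputs)
twice = assign_outputs(once, no_measurable_outputs)

print(twice.__name__)  # UnmeasurableNormalRV
print(once is twice)   # True
```

Unlike a name check, this compares the registered function itself, so it does not depend on how types are named. Because singledispatch falls back along the MRO, a subclass of an already-registered type would also be returned unchanged.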