y0-causal-inference/y0

Node lookup in NxMixedGraph should not depend on the value (`star`) a Variable takes

djinnome opened this issue · 0 comments

graph = NxMixedGraph.from_edges(directed=[(D, Z)])
graph.ancestors_inclusive(Variable(name='D', star=None))

returns:

{D}

However, querying the same variable with `star=False`,

graph = NxMixedGraph.from_edges(directed=[(D, Z)])
graph.ancestors_inclusive(Variable(name='D', star=False))

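raises a `NetworkXError` instead of returning a result. The node `D` stored in the graph matches `Variable(name='D', star=None)`, as the first example shows. `Variable(name='D', star=False)` does not match it, so `G.has_node(source)` fails inside `networkx.ancestors`. The error message is also misleading, because both variables are printed as `D`: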

----------------------------------------------------------------------------
NetworkXError                             Traceback (most recent call last)
<ipython-input-25-daab0c150eab> in <module>
      1 graph = NxMixedGraph.from_edges(directed=[(D, Z)])
      2 query=Variable(name='D', star=False)
----> 3 graph.ancestors_inclusive(query)

~/Projects/CausalInference/y0-causal-inference/y0/src/y0/graph.py in ancestors_inclusive(self, sources)
    397         """Ancestors of a set include the set itself."""
    398         sources = _ensure_set(sources)
--> 399         return _ancestors_inclusive(self.directed, sources)
    400 
    401     def topological_sort(self) -> Iterable[Variable]:

~/Projects/CausalInference/y0-causal-inference/y0/src/y0/graph.py in _ancestors_inclusive(graph, sources)
    417 
    418 def _ancestors_inclusive(graph: nx.DiGraph, sources: set[Variable]) -> set[Variable]:
--> 419     ancestors = set(
    420         itt.chain.from_iterable(nx.algorithms.dag.ancestors(graph, source) for source in sources)
    421     )

~/Projects/CausalInference/y0-causal-inference/y0/src/y0/graph.py in <genexpr>(.0)
    418 def _ancestors_inclusive(graph: nx.DiGraph, sources: set[Variable]) -> set[Variable]:
    419     ancestors = set(
--> 420         itt.chain.from_iterable(nx.algorithms.dag.ancestors(graph, source) for source in sources)
    421     )
    422     return sources | ancestors

~/.pyenv/versions/anaconda3-2020.11/lib/python3.8/site-packages/networkx/algorithms/dag.py in ancestors(G, source)
     90     """
     91     if not G.has_node(source):
---> 92         raise nx.NetworkXError("The node %s is not in the graph." % source)
     93     anc = set(n for n, d in nx.shortest_path_length(G, target=source).items())
     94     return anc - {source}

NetworkXError: The node D is not in the graph.