Skip to content

Commit

Permalink
[ENH] Add the ability to check the validity of an MAG (#91)
Browse files Browse the repository at this point in the history
* Added inducing path checking to MAG check
* made find_adc public and added tests

---------

Signed-off-by: “Aryan <“aryanroy5678@gmail.com”>
Co-authored-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
aryan26roy and adam2392 authored Aug 17, 2023
1 parent 013513b commit 96e58b9
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 1 deletion.
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ causal graph operations.
.. autosummary::
:toctree: generated/

valid_mag
has_adc
inducing_path
is_valid_mec_graph
possible_ancestors
Expand Down
3 changes: 2 additions & 1 deletion doc/whats_new/v0.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Version 0.2

Changelog
---------
-
- |Feature| Implement and test functions to validate an MAG and check the presence of almost directed cycles, by `Aryan Roy`_ (:pr:`91`)

Code and Documentation Contributors
-----------------------------------
Expand All @@ -34,4 +34,5 @@ Thanks to everyone who has contributed to the maintenance and improvement of
the project since version inception, including:

* `Adam Li`_
* `Aryan Roy`_

99 changes: 99 additions & 0 deletions pywhy_graphs/algorithms/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
"set_nodes_as_latent_confounders",
"is_valid_mec_graph",
"inducing_path",
"has_adc",
"valid_mag",
]


Expand Down Expand Up @@ -604,3 +606,100 @@ def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None):
break

return (path_exists, path)


def has_adc(G):
"""Check if a graph has an almost directed cycle (adc).
An almost directed cycle is a is a directed cycle containing
one bidirected edge. For example, ``A -> B -> C <-> A`` is an adc.
Parameters
----------
G : Graph
The graph.
Returns
-------
adc_present : bool
A boolean indicating whether an almost directed cycle is present or not.
"""

adc_present = False

biedges = G.bidirected_edges

for elem in G.nodes:
ancestors = nx.ancestors(G.sub_directed_graph(), elem)
descendants = nx.descendants(G.sub_directed_graph(), elem)
for elem in biedges:
if (elem[0] in ancestors and elem[1] in descendants) or (
elem[1] in ancestors and elem[0] in descendants
): # there is a bidirected edge from one of the ancestors to a descendant
return not adc_present

return adc_present


def valid_mag(G: ADMG, L: set = None, S: set = None):
"""Checks if the provided graph is a valid maximal ancestral graph (MAG).
A valid MAG as defined in :footcite:`Zhang2008` is a mixed edge graph that
only has directed and bi-directed edges, no directed or almost directed
cycles and no inducing paths between any two non-adjacent pair of nodes.
Parameters
----------
G : Graph
The graph.
Returns
-------
is_valid : bool
A boolean indicating whether the provided graph is a valid MAG or not.
"""

if L is None:
L = set()

if S is None:
S = set()

directed_sub_graph = G.sub_directed_graph()

all_nodes = set(G.nodes)

# check if there are any undirected edges or more than one edges b/w two nodes
for node in all_nodes:
nb = set(G.neighbors(node))
for elem in nb:
edge_data = G.get_edge_data(node, elem)
if edge_data["undirected"] is not None:
return False
elif (edge_data["bidirected"] is not None) and (edge_data["directed"] is not None):
return False

# check if there are any directed cyclces
try:
nx.find_cycle(directed_sub_graph) # raises a NetworkXNoCycle error
return False
except nx.NetworkXNoCycle:
pass

# check if there are any almost directed cycles
if has_adc(G): # if there is an ADC, it's not a valid MAG
return False

# check if there are any inducing paths between non-adjacent nodes

for source in all_nodes:
nb = set(G.neighbors(source))
cur_set = all_nodes - nb
cur_set.remove(source)
for dest in cur_set:
out = inducing_path(G, source, dest, L, S)
if out[0] is True:
return False

return True
120 changes: 120 additions & 0 deletions pywhy_graphs/algorithms/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,123 @@ def test_is_collider():
S = {"A"}

assert pywhy_graphs.inducing_path(admg, "Z", "Y", L, S)[0]


def test_has_adc():
# K -> H -> Z -> X -> Y -> J <- K
admg = ADMG()
admg.add_edge("Z", "X", admg.directed_edge_name)
admg.add_edge("X", "Y", admg.directed_edge_name)
admg.add_edge("Y", "J", admg.directed_edge_name)
admg.add_edge("H", "Z", admg.directed_edge_name)
admg.add_edge("K", "H", admg.directed_edge_name)
admg.add_edge("K", "J", admg.directed_edge_name)

assert not pywhy_graphs.has_adc(admg) # there is no cycle completed by a bidirected edge

# K -> H -> Z -> X -> Y -> J <-> K
admg = ADMG()
admg.add_edge("Z", "X", admg.directed_edge_name)
admg.add_edge("X", "Y", admg.directed_edge_name)
admg.add_edge("Y", "J", admg.directed_edge_name)
admg.add_edge("H", "Z", admg.directed_edge_name)
admg.add_edge("K", "H", admg.directed_edge_name)
admg.add_edge("Y", "J", admg.directed_edge_name)
admg.add_edge("K", "J", admg.bidirected_edge_name)

assert pywhy_graphs.has_adc(admg) # there is a bidirected edge from J to K, completing a cycle

# K -> H -> Z -> X -> Y <- J <-> K
admg = ADMG()
admg.add_edge("Z", "X", admg.directed_edge_name)
admg.add_edge("X", "Y", admg.directed_edge_name)
admg.add_edge("J", "Y", admg.directed_edge_name)
admg.add_edge("H", "Z", admg.directed_edge_name)
admg.add_edge("K", "H", admg.directed_edge_name)
admg.add_edge("K", "J", admg.bidirected_edge_name)

assert not pywhy_graphs.has_adc(admg) # Y <- J is not correctly oriented

# I -> H -> Z -> X -> Y -> J <-> K
# J -> I
admg = ADMG()
admg.add_edge("Z", "X", admg.directed_edge_name)
admg.add_edge("X", "Y", admg.directed_edge_name)
admg.add_edge("Y", "J", admg.directed_edge_name)
admg.add_edge("H", "Z", admg.directed_edge_name)
admg.add_edge("K", "H", admg.directed_edge_name)
admg.add_edge("Y", "H", admg.directed_edge_name)
admg.add_edge("K", "J", admg.bidirected_edge_name)

assert pywhy_graphs.has_adc(admg) # J <-> K completes an otherwise directed cycle


def test_valid_mag():
# K -> H -> Z -> X -> Y -> J <- K
admg = ADMG()
admg.add_edge("Z", "X", admg.directed_edge_name)
admg.add_edge("X", "Y", admg.directed_edge_name)
admg.add_edge("Y", "J", admg.directed_edge_name)
admg.add_edge("H", "Z", admg.directed_edge_name)
admg.add_edge("K", "H", admg.directed_edge_name)
admg.add_edge("K", "J", admg.directed_edge_name)

S = {"J"}
L = {}

assert not pywhy_graphs.valid_mag(
admg, L, S # J is in S and is a collider on the path Y -> J <- K
)

S = {}

assert pywhy_graphs.valid_mag(admg, L, S) # there are no valid inducing paths

# K -> H -> Z -> X -> Y -> J -> K
admg = ADMG()
admg.add_edge("Z", "X", admg.directed_edge_name)
admg.add_edge("X", "Y", admg.directed_edge_name)
admg.add_edge("Y", "J", admg.directed_edge_name)
admg.add_edge("H", "Z", admg.directed_edge_name)
admg.add_edge("K", "H", admg.directed_edge_name)
admg.add_edge("J", "K", admg.directed_edge_name)

L = {}

assert not pywhy_graphs.valid_mag(admg, L, S) # there is a directed cycle

# K -> H -> Z -> X -> Y -> J <- K
# H <-> J
admg = ADMG()
admg.add_edge("Z", "X", admg.directed_edge_name)
admg.add_edge("X", "Y", admg.directed_edge_name)
admg.add_edge("Y", "J", admg.directed_edge_name)
admg.add_edge("H", "Z", admg.directed_edge_name)
admg.add_edge("K", "H", admg.directed_edge_name)
admg.add_edge("K", "J", admg.directed_edge_name)
admg.add_edge("H", "J", admg.bidirected_edge_name)

assert not pywhy_graphs.valid_mag(admg) # there is an almost directed cycle

admg = ADMG()
admg.add_edge("Z", "X", admg.directed_edge_name)
admg.add_edge("X", "Y", admg.directed_edge_name)
admg.add_edge("Y", "J", admg.directed_edge_name)
admg.add_edge("H", "Z", admg.directed_edge_name)
admg.add_edge("K", "H", admg.directed_edge_name)
admg.add_edge("K", "J", admg.directed_edge_name)
admg.add_edge("H", "J", admg.bidirected_edge_name)
admg.add_edge("H", "J", admg.directed_edge_name)

assert not pywhy_graphs.valid_mag(admg) # there are two edges between H and J

admg = ADMG()
admg.add_edge("Z", "X", admg.directed_edge_name)
admg.add_edge("X", "Y", admg.directed_edge_name)
admg.add_edge("Y", "J", admg.directed_edge_name)
admg.add_edge("H", "Z", admg.directed_edge_name)
admg.add_edge("K", "H", admg.directed_edge_name)
admg.add_edge("K", "J", admg.directed_edge_name)
admg.add_edge("H", "J", admg.undirected_edge_name)

assert not pywhy_graphs.valid_mag(admg) # there is an undirected edge between H and J

0 comments on commit 96e58b9

Please sign in to comment.