Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add the ability to check the validity of an MAG #91

Merged
merged 16 commits into from
Aug 17, 2023
Merged
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ causal graph operations.
.. autosummary::
:toctree: generated/

valid_mag
has_adc
inducing_path
is_valid_mec_graph
possible_ancestors
Expand Down
3 changes: 2 additions & 1 deletion doc/whats_new/v0.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Version 0.2

Changelog
---------
-
- |Feature| Implement and test functions to validate an MAG and check the presence of almost directed cycles, by `Aryan Roy`_ (:pr:`91`)

Code and Documentation Contributors
-----------------------------------
Expand All @@ -34,4 +34,5 @@ Thanks to everyone who has contributed to the maintenance and improvement of
the project since version inception, including:

* `Adam Li`_
* `Aryan Roy`_

99 changes: 99 additions & 0 deletions pywhy_graphs/algorithms/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
"set_nodes_as_latent_confounders",
"is_valid_mec_graph",
"inducing_path",
"has_adc",
"valid_mag",
]


Expand Down Expand Up @@ -604,3 +606,100 @@ def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None):
break

return (path_exists, path)


def has_adc(G):
    """Check if a graph has an almost directed cycle (adc).

    An almost directed cycle is a directed cycle in which one edge is
    replaced by a bidirected edge. Equivalently, it is present whenever two
    nodes are joined by a bidirected edge and a directed path of length one
    or more leads from one of them to the other. For example, both
    ``A -> B -> C <-> A`` and ``A -> B <-> A`` are adcs.

    Parameters
    ----------
    G : Graph
        The graph. It must expose ``bidirected_edges`` and
        ``sub_directed_graph()``.

    Returns
    -------
    adc_present : bool
        A boolean indicating whether an almost directed cycle is present or not.
    """
    # The directed part of the graph is invariant during the scan: build it once.
    directed_graph = G.sub_directed_graph()

    for edge in G.bidirected_edges:
        first, second = edge[0], edge[1]
        # ``first <-> second`` closes an almost directed cycle when either
        # endpoint is a directed descendant of the other. Checking the
        # endpoints directly (rather than looking for a node that sits
        # strictly between them) also catches the two-node case where the
        # directed path is a single edge. Note that ``nx.descendants`` never
        # contains the source node itself.
        if second in nx.descendants(directed_graph, first):
            return True
        if first in nx.descendants(directed_graph, second):
            return True

    return False


def valid_mag(G: ADMG, L: set = None, S: set = None) -> bool:
    """Check if the provided graph is a valid maximal ancestral graph (MAG).

    A valid MAG as defined in :footcite:`Zhang2008` is a mixed edge graph that
    only has directed and bi-directed edges, no directed or almost directed
    cycles and no inducing paths between any two non-adjacent pair of nodes.

    Parameters
    ----------
    G : Graph
        The graph.
    L : set, optional
        Forwarded unchanged as the ``L`` argument of every ``inducing_path``
        call. Defaults to an empty set.
    S : set, optional
        Forwarded unchanged as the ``S`` argument of every ``inducing_path``
        call. Defaults to an empty set.

    Returns
    -------
    is_valid : bool
        A boolean indicating whether the provided graph is a valid MAG or not.

    """
    if L is None:
        L = set()
    if S is None:
        S = set()

    all_nodes = set(G.nodes)

    # check if there are any undirected edges or more than one edge b/w two nodes
    for node in all_nodes:
        for neighbor in set(G.neighbors(node)):
            edge_data = G.get_edge_data(node, neighbor)
            if edge_data["undirected"] is not None:
                return False
            if edge_data["bidirected"] is not None and edge_data["directed"] is not None:
                return False

    # check if there are any directed cycles
    try:
        # raises NetworkXNoCycle when the directed part of the graph is acyclic
        nx.find_cycle(G.sub_directed_graph())
    except nx.NetworkXNoCycle:
        pass
    else:
        return False

    # check if there are any almost directed cycles
    if has_adc(G):
        return False

    # check if there are any inducing paths between non-adjacent nodes
    for source in all_nodes:
        non_adjacent = all_nodes - set(G.neighbors(source))
        # ``discard`` rather than ``remove``: if ``source`` is listed among its
        # own neighbors (e.g. a self-loop) it has already been subtracted, and
        # ``remove`` would raise ``KeyError``.
        non_adjacent.discard(source)
        for dest in non_adjacent:
            if inducing_path(G, source, dest, L, S)[0]:
                return False

    return True
120 changes: 120 additions & 0 deletions pywhy_graphs/algorithms/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,123 @@ def test_is_collider():
S = {"A"}

assert pywhy_graphs.inducing_path(admg, "Z", "Y", L, S)[0]


def test_has_adc():
    """Exercise ``has_adc`` on graphs with and without an almost directed cycle."""
    # K -> H -> Z -> X -> Y -> J <- K  (directed edges only)
    admg = ADMG()
    for tail, head in [("Z", "X"), ("X", "Y"), ("Y", "J"), ("H", "Z"), ("K", "H"), ("K", "J")]:
        admg.add_edge(tail, head, admg.directed_edge_name)

    # there is no cycle completed by a bidirected edge
    assert not pywhy_graphs.has_adc(admg)

    # K -> H -> Z -> X -> Y -> J <-> K
    admg = ADMG()
    # ("Y", "J") appears twice, mirroring the original edge sequence.
    for tail, head in [("Z", "X"), ("X", "Y"), ("Y", "J"), ("H", "Z"), ("K", "H"), ("Y", "J")]:
        admg.add_edge(tail, head, admg.directed_edge_name)
    admg.add_edge("K", "J", admg.bidirected_edge_name)

    # the bidirected edge between J and K completes a cycle
    assert pywhy_graphs.has_adc(admg)

    # K -> H -> Z -> X -> Y <- J <-> K
    admg = ADMG()
    for tail, head in [("Z", "X"), ("X", "Y"), ("J", "Y"), ("H", "Z"), ("K", "H")]:
        admg.add_edge(tail, head, admg.directed_edge_name)
    admg.add_edge("K", "J", admg.bidirected_edge_name)

    # Y <- J points the wrong way, so no almost directed cycle is formed
    assert not pywhy_graphs.has_adc(admg)

    # K -> H -> Z -> X -> Y -> J <-> K
    # Y -> H
    admg = ADMG()
    for tail, head in [("Z", "X"), ("X", "Y"), ("Y", "J"), ("H", "Z"), ("K", "H"), ("Y", "H")]:
        admg.add_edge(tail, head, admg.directed_edge_name)
    admg.add_edge("K", "J", admg.bidirected_edge_name)

    # J <-> K completes an otherwise directed cycle
    assert pywhy_graphs.has_adc(admg)


def test_valid_mag():
    """Exercise ``valid_mag`` on one valid graph and each way a graph can be invalid."""
    # Directed edges shared by most of the graphs below:
    # K -> H -> Z -> X -> Y -> J <- K
    base_edges = (("Z", "X"), ("X", "Y"), ("Y", "J"), ("H", "Z"), ("K", "H"), ("K", "J"))

    def build(directed_edges):
        # Fresh ADMG holding the given directed edges, added in order.
        graph = ADMG()
        for tail, head in directed_edges:
            graph.add_edge(tail, head, graph.directed_edge_name)
        return graph

    admg = build(base_edges)

    S = {"J"}
    L = {}

    # J is in S and is a collider on the path Y -> J <- K
    assert not pywhy_graphs.valid_mag(admg, L, S)

    S = {}

    # there are no valid inducing paths
    assert pywhy_graphs.valid_mag(admg, L, S)

    # K -> H -> Z -> X -> Y -> J -> K
    admg = build((("Z", "X"), ("X", "Y"), ("Y", "J"), ("H", "Z"), ("K", "H"), ("J", "K")))

    L = {}

    # there is a directed cycle
    assert not pywhy_graphs.valid_mag(admg, L, S)

    # K -> H -> Z -> X -> Y -> J <- K
    # H <-> J
    admg = build(base_edges)
    admg.add_edge("H", "J", admg.bidirected_edge_name)

    # there is an almost directed cycle
    assert not pywhy_graphs.valid_mag(admg)

    # same graph, with an additional directed edge H -> J
    admg = build(base_edges)
    admg.add_edge("H", "J", admg.bidirected_edge_name)
    admg.add_edge("H", "J", admg.directed_edge_name)

    # there are two edges between H and J
    assert not pywhy_graphs.valid_mag(admg)

    # base graph plus an undirected edge H - J
    admg = build(base_edges)
    admg.add_edge("H", "J", admg.undirected_edge_name)

    # there is an undirected edge between H and J
    assert not pywhy_graphs.valid_mag(admg)