diff --git a/doc/api.rst b/doc/api.rst index 990414cf..dfa8b06d 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -43,6 +43,8 @@ causal graph operations. .. autosummary:: :toctree: generated/ + valid_mag + has_adc inducing_path is_valid_mec_graph possible_ancestors diff --git a/doc/whats_new/v0.2.rst b/doc/whats_new/v0.2.rst index 3c7bad42..2bfec356 100644 --- a/doc/whats_new/v0.2.rst +++ b/doc/whats_new/v0.2.rst @@ -25,7 +25,7 @@ Version 0.2 Changelog --------- -- +- |Feature| Implement and test functions to validate an MAG and check the presence of almost directed cycles, by `Aryan Roy`_ (:pr:`91`) Code and Documentation Contributors ----------------------------------- @@ -34,4 +34,5 @@ Thanks to everyone who has contributed to the maintenance and improvement of the project since version inception, including: * `Adam Li`_ +* `Aryan Roy`_ diff --git a/pywhy_graphs/algorithms/generic.py b/pywhy_graphs/algorithms/generic.py index 6df21009..27d41211 100644 --- a/pywhy_graphs/algorithms/generic.py +++ b/pywhy_graphs/algorithms/generic.py @@ -13,6 +13,8 @@ "set_nodes_as_latent_confounders", "is_valid_mec_graph", "inducing_path", + "has_adc", + "valid_mag", ] @@ -604,3 +606,100 @@ def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None): break return (path_exists, path) + + +def has_adc(G): + """Check if a graph has an almost directed cycle (adc). + + An almost directed cycle is a is a directed cycle containing + one bidirected edge. For example, ``A -> B -> C <-> A`` is an adc. + + Parameters + ---------- + G : Graph + The graph. + + Returns + ------- + adc_present : bool + A boolean indicating whether an almost directed cycle is present or not. + """ + + adc_present = False + + biedges = G.bidirected_edges + + for elem in G.nodes: + ancestors = nx.ancestors(G.sub_directed_graph(), elem) + descendants = nx.descendants(G.sub_directed_graph(), elem) + for elem in biedges: + if (elem[0] in ancestors and elem[1] in descendants) or ( + elem[1] in ancestors and elem[0] in descendants + ): # there is a bidirected edge from one of the ancestors to a descendant + return not adc_present + + return adc_present + + +def valid_mag(G: ADMG, L: set = None, S: set = None): + """Checks if the provided graph is a valid maximal ancestral graph (MAG). + + A valid MAG as defined in :footcite:`Zhang2008` is a mixed edge graph that + only has directed and bi-directed edges, no directed or almost directed + cycles and no inducing paths between any two non-adjacent pair of nodes. + + Parameters + ---------- + G : Graph + The graph. + + Returns + ------- + is_valid : bool + A boolean indicating whether the provided graph is a valid MAG or not. + + """ + + if L is None: + L = set() + + if S is None: + S = set() + + directed_sub_graph = G.sub_directed_graph() + + all_nodes = set(G.nodes) + + # check if there are any undirected edges or more than one edges b/w two nodes + for node in all_nodes: + nb = set(G.neighbors(node)) + for elem in nb: + edge_data = G.get_edge_data(node, elem) + if edge_data["undirected"] is not None: + return False + elif (edge_data["bidirected"] is not None) and (edge_data["directed"] is not None): + return False + + # check if there are any directed cyclces + try: + nx.find_cycle(directed_sub_graph) # raises a NetworkXNoCycle error + return False + except nx.NetworkXNoCycle: + pass + + # check if there are any almost directed cycles + if has_adc(G): # if there is an ADC, it's not a valid MAG + return False + + # check if there are any inducing paths between non-adjacent nodes + + for source in all_nodes: + nb = set(G.neighbors(source)) + cur_set = all_nodes - nb + cur_set.remove(source) + for dest in cur_set: + out = inducing_path(G, source, dest, L, S) + if out[0] is True: + return False + + return True diff --git a/pywhy_graphs/algorithms/tests/test_generic.py b/pywhy_graphs/algorithms/tests/test_generic.py index d0a52981..0f717893 100644 --- a/pywhy_graphs/algorithms/tests/test_generic.py +++ b/pywhy_graphs/algorithms/tests/test_generic.py @@ -228,3 +228,123 @@ def test_is_collider(): S = {"A"} assert pywhy_graphs.inducing_path(admg, "Z", "Y", L, S)[0] + + +def test_has_adc(): + # K -> H -> Z -> X -> Y -> J <- K + admg = ADMG() + admg.add_edge("Z", "X", admg.directed_edge_name) + admg.add_edge("X", "Y", admg.directed_edge_name) + admg.add_edge("Y", "J", admg.directed_edge_name) + admg.add_edge("H", "Z", admg.directed_edge_name) + admg.add_edge("K", "H", admg.directed_edge_name) + admg.add_edge("K", "J", admg.directed_edge_name) + + assert not pywhy_graphs.has_adc(admg) # there is no cycle completed by a bidirected edge + + # K -> H -> Z -> X -> Y -> J <-> K + admg = ADMG() + admg.add_edge("Z", "X", admg.directed_edge_name) + admg.add_edge("X", "Y", admg.directed_edge_name) + admg.add_edge("Y", "J", admg.directed_edge_name) + admg.add_edge("H", "Z", admg.directed_edge_name) + admg.add_edge("K", "H", admg.directed_edge_name) + admg.add_edge("Y", "J", admg.directed_edge_name) + admg.add_edge("K", "J", admg.bidirected_edge_name) + + assert pywhy_graphs.has_adc(admg) # there is a bidirected edge from J to K, completing a cycle + + # K -> H -> Z -> X -> Y <- J <-> K + admg = ADMG() + admg.add_edge("Z", "X", admg.directed_edge_name) + admg.add_edge("X", "Y", admg.directed_edge_name) + admg.add_edge("J", "Y", admg.directed_edge_name) + admg.add_edge("H", "Z", admg.directed_edge_name) + admg.add_edge("K", "H", admg.directed_edge_name) + admg.add_edge("K", "J", admg.bidirected_edge_name) + + assert not pywhy_graphs.has_adc(admg) # Y <- J is not correctly oriented + + # I -> H -> Z -> X -> Y -> J <-> K + # J -> I + admg = ADMG() + admg.add_edge("Z", "X", admg.directed_edge_name) + admg.add_edge("X", "Y", admg.directed_edge_name) + admg.add_edge("Y", "J", admg.directed_edge_name) + admg.add_edge("H", "Z", admg.directed_edge_name) + admg.add_edge("K", "H", admg.directed_edge_name) + admg.add_edge("Y", "H", admg.directed_edge_name) + admg.add_edge("K", "J", admg.bidirected_edge_name) + + assert pywhy_graphs.has_adc(admg) # J <-> K completes an otherwise directed cycle + + +def test_valid_mag(): + # K -> H -> Z -> X -> Y -> J <- K + admg = ADMG() + admg.add_edge("Z", "X", admg.directed_edge_name) + admg.add_edge("X", "Y", admg.directed_edge_name) + admg.add_edge("Y", "J", admg.directed_edge_name) + admg.add_edge("H", "Z", admg.directed_edge_name) + admg.add_edge("K", "H", admg.directed_edge_name) + admg.add_edge("K", "J", admg.directed_edge_name) + + S = {"J"} + L = {} + + assert not pywhy_graphs.valid_mag( + admg, L, S # J is in S and is a collider on the path Y -> J <- K + ) + + S = {} + + assert pywhy_graphs.valid_mag(admg, L, S) # there are no valid inducing paths + + # K -> H -> Z -> X -> Y -> J -> K + admg = ADMG() + admg.add_edge("Z", "X", admg.directed_edge_name) + admg.add_edge("X", "Y", admg.directed_edge_name) + admg.add_edge("Y", "J", admg.directed_edge_name) + admg.add_edge("H", "Z", admg.directed_edge_name) + admg.add_edge("K", "H", admg.directed_edge_name) + admg.add_edge("J", "K", admg.directed_edge_name) + + L = {} + + assert not pywhy_graphs.valid_mag(admg, L, S) # there is a directed cycle + + # K -> H -> Z -> X -> Y -> J <- K + # H <-> J + admg = ADMG() + admg.add_edge("Z", "X", admg.directed_edge_name) + admg.add_edge("X", "Y", admg.directed_edge_name) + admg.add_edge("Y", "J", admg.directed_edge_name) + admg.add_edge("H", "Z", admg.directed_edge_name) + admg.add_edge("K", "H", admg.directed_edge_name) + admg.add_edge("K", "J", admg.directed_edge_name) + admg.add_edge("H", "J", admg.bidirected_edge_name) + + assert not pywhy_graphs.valid_mag(admg) # there is an almost directed cycle + + admg = ADMG() + admg.add_edge("Z", "X", admg.directed_edge_name) + admg.add_edge("X", "Y", admg.directed_edge_name) + admg.add_edge("Y", "J", admg.directed_edge_name) + admg.add_edge("H", "Z", admg.directed_edge_name) + admg.add_edge("K", "H", admg.directed_edge_name) + admg.add_edge("K", "J", admg.directed_edge_name) + admg.add_edge("H", "J", admg.bidirected_edge_name) + admg.add_edge("H", "J", admg.directed_edge_name) + + assert not pywhy_graphs.valid_mag(admg) # there are two edges between H and J + + admg = ADMG() + admg.add_edge("Z", "X", admg.directed_edge_name) + admg.add_edge("X", "Y", admg.directed_edge_name) + admg.add_edge("Y", "J", admg.directed_edge_name) + admg.add_edge("H", "Z", admg.directed_edge_name) + admg.add_edge("K", "H", admg.directed_edge_name) + admg.add_edge("K", "J", admg.directed_edge_name) + admg.add_edge("H", "J", admg.undirected_edge_name) + + assert not pywhy_graphs.valid_mag(admg) # there is an undirected edge between H and J