diff --git a/doc/whats_new/v0.2.rst b/doc/whats_new/v0.2.rst
index bc7bcf9b..d08bc8a1 100644
--- a/doc/whats_new/v0.2.rst
+++ b/doc/whats_new/v0.2.rst
@@ -27,6 +27,7 @@ Changelog
 ---------
 - |Feature| Implement and test functions to validate an MAG and check the presence of almost directed cycles, by `Aryan Roy`_ (:pr:`91`)
 - |Feature| Implement and test functions to convert a DAG to MAG, by `Aryan Roy`_ (:pr:`96`)
+- |Feature| Implement and test functions to convert a PAG to MAG, by `Aryan Roy`_ (:pr:`93`)
 
 Code and Documentation Contributors
 -----------------------------------
diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py
index 4f8f2a28..d6c69e45 100644
--- a/pywhy_graphs/algorithms/pag.py
+++ b/pywhy_graphs/algorithms/pag.py
@@ -1,12 +1,12 @@
 import logging
 from collections import deque
-from itertools import chain
+from itertools import chain, combinations, product
 from typing import List, Optional, Set, Tuple
 
 import networkx as nx
 import numpy as np
 
-from pywhy_graphs import PAG, StationaryTimeSeriesPAG
+from pywhy_graphs import ADMG, CPDAG, PAG, StationaryTimeSeriesPAG
 from pywhy_graphs.algorithms.generic import single_source_shortest_mixed_path
 from pywhy_graphs.typing import Node, TsNode
 
@@ -22,6 +22,7 @@
     "pds_t",
     "pds_t_path",
     "is_definite_noncollider",
+    "pag_to_mag",
 ]
 
 
@@ -908,3 +909,286 @@ def _check_ts_node(node):
         )
     if node[1] > 0:
         raise ValueError(f"All lag points should be 0, or less. You passed in {node}.")
+
+
+def _apply_meek_rules(graph: CPDAG) -> None:
+    """Orient undirected edges of a CPDAG in place using the Meek rules.
+
+    The Meek rules :footcite:`Meek1995` are deterministic in the sense that they
+    are logical characterizations of what edges must be present given the rest
+    of the local graph structure. All four rules are applied to every pair of
+    adjacent nodes, and the sweep is repeated until a full pass orients no edge.
+
+    Parameters
+    ----------
+    graph : CPDAG
+        A graph containing directed and undirected edges. It is modified in place.
+    """
+    while True:
+        change_flag = False
+        for i in graph.nodes:
+            for j in graph.neighbors(i):
+                if i == j:
+                    continue
+                # Rule 1: Orient i-j into i->j whenever there is an arrow k->i
+                # such that k and j are nonadjacent.
+                r1_add = _meek_rule1(graph, i, j)
+
+                # Rule 2: Orient i-j into i->j whenever there is a chain
+                # i->k->j.
+                r2_add = _meek_rule2(graph, i, j)
+
+                # Rule 3: Orient i-j into i->j whenever there are two chains
+                # i-k->j and i-l->j such that k and l are nonadjacent.
+                r3_add = _meek_rule3(graph, i, j)
+
+                # Rule 4: Orient i-j into i->j whenever there are two chains
+                # i-k->l and k->l->j such that k and j are nonadjacent.
+                r4_add = _meek_rule4(graph, i, j)
+
+                # all four rules run before this check because each call may
+                # orient an edge as a side effect
+                if r1_add or r2_add or r3_add or r4_add:
+                    change_flag = True
+        # a full sweep that oriented nothing means a fixed point was reached
+        if not change_flag:
+            break
+
+
+def _meek_rule1(graph: CPDAG, i: str, j: str) -> bool:
+    """Apply rule 1 of Meek's rules.
+
+    Looks for i - j such that k -> i, such that (k,i,j)
+    is an unshielded triple. Then can orient i - j as i -> j.
+
+    Returns
+    -------
+    added_arrows : bool
+        True if i - j was oriented as i -> j, False otherwise.
+    """
+    added_arrows = False
+
+    # Check if i-j.
+    if graph.has_edge(i, j, graph.undirected_edge_name):
+        for k in graph.predecessors(i):
+            # Skip if k and j are adjacent because then it is a
+            # shielded triple
+            if j in graph.neighbors(k):
+                continue
+
+            # check if the triple is in the graph's excluded triples
+            if frozenset((k, i, j)) in graph.excluded_triples:
+                continue
+
+            # Make i-j into i->j
+            graph.orient_uncertain_edge(i, j)
+
+            added_arrows = True
+            break
+    return added_arrows
+
+
+def _meek_rule2(graph: CPDAG, i: str, j: str) -> bool:
+    """Apply rule 2 of Meek's rules.
+
+    Check for i - j, and then looks for i -> k -> j
+    triple, to orient i - j as i -> j.
+
+    Returns
+    -------
+    added_arrows : bool
+        True if i - j was oriented as i -> j, False otherwise.
+    """
+    added_arrows = False
+
+    # Check if i-j.
+    if graph.has_edge(i, j, graph.undirected_edge_name):
+        # Find nodes k where i->k, skipping edges that are directed both ways
+        child_i = set()
+        for k in graph.successors(i):
+            if not graph.has_edge(k, i, graph.directed_edge_name):
+                child_i.add(k)
+        # Find nodes k where k->j, skipping edges that are directed both ways
+        parent_j = set()
+        for k in graph.predecessors(j):
+            if not graph.has_edge(j, k, graph.directed_edge_name):
+                parent_j.add(k)
+
+        # Check if there is any node k where i->k->j.
+        candidate_k = child_i.intersection(parent_j)
+        # if the graph has excluded triples, we would check at this point
+        if graph.excluded_triples:
+            # drop the candidates whose triple is excluded; a new set is built
+            # because removing from a set while iterating over it raises
+            # RuntimeError
+            candidate_k = {
+                k for k in candidate_k if frozenset((i, k, j)) not in graph.excluded_triples
+            }
+
+        # if there are candidate 'k' nodes, then orient the edge accordingly
+        if candidate_k:
+            # Make i-j into i->j
+            graph.orient_uncertain_edge(i, j)
+            added_arrows = True
+    return added_arrows
+
+
+def _meek_rule3(graph: CPDAG, i: str, j: str) -> bool:
+    """Apply rule 3 of Meek's rules.
+
+    Check for i - j, and then looks for k -> j <- l
+    collider, and i - k and i - l, then orient i -> j.
+
+    Returns
+    -------
+    added_arrows : bool
+        True if i - j was oriented as i -> j, False otherwise.
+    """
+    added_arrows = False
+
+    # Check if i-j first
+    if graph.has_edge(i, j, graph.undirected_edge_name):
+        # For all the pairs of nodes adjacent to i,
+        # look for (k, l), such that k -> j and l -> j
+        for (k, l) in combinations(graph.neighbors(i), 2):
+            # Skip if k and l are adjacent.
+            if l in graph.neighbors(k):
+                continue
+            # Skip if not k->j.
+            if graph.has_edge(j, k, graph.directed_edge_name) or (
+                not graph.has_edge(k, j, graph.directed_edge_name)
+            ):
+                continue
+            # Skip if not l->j.
+            if graph.has_edge(j, l, graph.directed_edge_name) or (
+                not graph.has_edge(l, j, graph.directed_edge_name)
+            ):
+                continue
+
+            # check if the triple is inside graph's excluded triples
+            if frozenset((l, i, k)) in graph.excluded_triples:
+                continue
+
+            # if i - k and i - l, then at this point, we have a valid path
+            # to orient
+            if graph.has_edge(k, i, graph.undirected_edge_name) and graph.has_edge(
+                l, i, graph.undirected_edge_name
+            ):
+                graph.orient_uncertain_edge(i, j)
+                added_arrows = True
+                break
+    return added_arrows
+
+
+def _meek_rule4(graph: CPDAG, i: str, j: str) -> bool:
+    """Apply rule 4 of Meek's rules.
+
+    Check for i - j, and then looks for i - k -> l -> j, to orient i - j as i -> j.
+
+    Returns
+    -------
+    added_arrows : bool
+        True if i - j was oriented as i -> j, False otherwise.
+    """
+    added_arrows = False
+
+    # Check if i-j.
+    if graph.has_edge(i, j, graph.undirected_edge_name):
+        # Find nodes k where k is i-k
+        adj_i = set()
+        for k in graph.neighbors(i):
+            if not graph.has_edge(k, i, graph.directed_edge_name):
+                adj_i.add(k)
+
+        # Find the parents of j, skipping edges that are directed both ways
+        parent_j = set()
+        for k in graph.predecessors(j):
+            if not graph.has_edge(j, k, graph.directed_edge_name):
+                parent_j.add(k)
+
+        dedges = set(graph.directed_edges)
+
+        # pair every neighbor k of i with every parent of j (a cartesian product,
+        # so no pairing is missed when j has more parents than i has neighbors)
+        # and keep the pairs joined by a directed edge k->parent for which k and
+        # j are nonadjacent, using the same adjacency test as rule 1
+        candidate_k = {
+            (k, parent)
+            for k, parent in product(adj_i, parent_j)
+            if (k, parent) in dedges and j not in graph.neighbors(k)
+        }
+
+        # if there are candidate 'k->l' pairs, then orient the edge accordingly
+        if candidate_k:
+            # Make i-j into i->j
+            graph.orient_uncertain_edge(i, j)
+            added_arrows = True
+    return added_arrows
+
+
+def pag_to_mag(graph):
+    """Sample a MAG from a PAG using Zhang's algorithm.
+
+    Using the algorithm defined in Theorem 2 of :footcite:`Zhang2008`, which turns all
+    o-> edges to -> and -o edges to ->, then it converts the remaining o-o edges into
+    a DAG with no unshielded colliders using the Meek rules.
+
+    Parameters
+    ----------
+    graph : PAG
+        The PAG to convert. All edits are made to a copy returned by ``graph.copy()``.
+
+    Returns
+    -------
+    mag : ADMG
+        The MAG constructed from the PAG.
+    """
+    copy_graph = graph.copy()
+
+    cedges = set(copy_graph.circle_edges)
+    dedges = set(copy_graph.directed_edges)
+
+    temp_cpdag = CPDAG()
+
+    to_remove = []
+    to_reorient = []
+    to_add = []
+
+    for u, v in cedges:
+        if (v, u) in dedges:  # remove the circle end from a 'o-->' edge to make a '-->' edge
+            to_remove.append((u, v))
+        elif (v, u) not in cedges:  # reorient a '--o' edge to '-->'
+            to_reorient.append((u, v))
+        elif (v, u) not in to_add:  # add every 'o--o' edge to the cpdag exactly once
+            to_add.append((u, v))
+    for u, v in to_remove:
+        copy_graph.remove_edge(u, v, copy_graph.circle_edge_name)
+    for u, v in to_reorient:
+        copy_graph.orient_uncertain_edge(u, v)
+    for u, v in to_add:
+        temp_cpdag.add_edge(v, u, temp_cpdag.undirected_edge_name)
+
+    # convert the o-o edges into a DAG with no unshielded colliders: direct one
+    # undirected edge, let the Meek rules propagate that choice, and repeat until
+    # no undirected edge is left
+    while True:
+        undedges = list(temp_cpdag.undirected_edges)
+        if not undedges:
+            break
+        u, v = undedges[0]
+        temp_cpdag.remove_edge(u, v, temp_cpdag.undirected_edge_name)
+        temp_cpdag.add_edge(u, v, temp_cpdag.directed_edge_name)
+        _apply_meek_rules(temp_cpdag)
+
+    mag = ADMG()  # provisional MAG
+
+    # construct the final MAG
+    # NOTE(review): only directed edges are carried over; confirm whether other
+    # edge types of the PAG, or nodes without edges, should be copied as well
+    for u, v in copy_graph.directed_edges:
+        mag.add_edge(u, v, mag.directed_edge_name)
+
+    for u, v in temp_cpdag.directed_edges:
+        mag.add_edge(u, v, mag.directed_edge_name)
+
+    return mag
diff --git a/pywhy_graphs/algorithms/tests/test_pag.py b/pywhy_graphs/algorithms/tests/test_pag.py
index c09607a8..1f0deb53 100644
--- a/pywhy_graphs/algorithms/tests/test_pag.py
+++ b/pywhy_graphs/algorithms/tests/test_pag.py
@@ -647,3 +647,80 @@ def test_pdst(pdst_graph):
     ex_pdsep_t = pds_t(G, ("E", 0), ("x", -1))
     assert ("y", -2) not in xe_pdsep_t
     assert ("y", -2) not in ex_pdsep_t
+
+
+def test_pag_to_mag():
+
+    # C o- A o-> D <-o B
+    # B o-o A o-o C o-> D
+
+    pag = PAG()
+    pag.add_edge("A", "D", pag.directed_edge_name)
+    pag.add_edge("A", "C", pag.circle_edge_name)
+    pag.add_edge("D", "A", pag.circle_edge_name)
+    pag.add_edge("B", "D", pag.directed_edge_name)
+    pag.add_edge("C", "D", pag.directed_edge_name)
+    pag.add_edge("D", "B", pag.circle_edge_name)
+    pag.add_edge("D", "C", pag.circle_edge_name)
+    pag.add_edge("C", "A", pag.circle_edge_name)
+    pag.add_edge("B", "A", pag.circle_edge_name)
+    pag.add_edge("A", "B", pag.circle_edge_name)
+
+    out_mag = pywhy_graphs.pag_to_mag(pag)
+
+    # C <- A -> B -> D or C -> A -> B -> D or C <- A <- B -> D
+    # A -> D <- C
+
+    assert (
+        ((out_mag.has_edge("A", "B")) or (out_mag.has_edge("B", "A")))
+        and ((out_mag.has_edge("A", "C")) or (out_mag.has_edge("C", "A")))
+        and (out_mag.has_edge("A", "D"))
+        and (out_mag.has_edge("B", "D"))
+        and (out_mag.has_edge("C", "D"))
+    )
+
+    # D o-> A <-o B
+    # D o-o B
+    pag = PAG()
+    pag.add_edge("A", "B", pag.circle_edge_name)
+    pag.add_edge("B", "A", pag.directed_edge_name)
+    pag.add_edge("D", "A", pag.directed_edge_name)
+    pag.add_edge("A", "D", pag.circle_edge_name)
+    pag.add_edge("D", "B", pag.circle_edge_name)
+    pag.add_edge("B", "D", pag.circle_edge_name)
+
+    out_mag = pywhy_graphs.pag_to_mag(pag)
+
+    # B -> A <- D
+    # D -> B or D <- B
+
+    assert (
+        out_mag.has_edge("B", "A")
+        and out_mag.has_edge("D", "A")
+        and (out_mag.has_edge("D", "B") or out_mag.has_edge("B", "D"))
+    )
+
+    # A -> B <- C o-o D
+    # D o-o E -> B
+
+    pag = PAG()
+    pag.add_edge("A", "B", pag.directed_edge_name)
+    pag.add_edge("C", "B", pag.directed_edge_name)
+    pag.add_edge("E", "B", pag.directed_edge_name)
+    pag.add_edge("E", "D", pag.circle_edge_name)
+    pag.add_edge("C", "D", pag.circle_edge_name)
+    pag.add_edge("D", "E", pag.circle_edge_name)
+    pag.add_edge("D", "C", pag.circle_edge_name)
+
+    out_mag = pywhy_graphs.pag_to_mag(pag)
+
+    # A -> B <- C <- D or A -> B <- C -> D
+    # D <- E -> B or D -> E -> B
+
+    assert (
+        out_mag.has_edge("A", "B")
+        and out_mag.has_edge("C", "B")
+        and out_mag.has_edge("E", "B")
+        and (out_mag.has_edge("E", "D") or out_mag.has_edge("D", "E"))
+        and (out_mag.has_edge("D", "C") or out_mag.has_edge("C", "D"))
+    )