From 68b98943ce437cd4153a3d865a9a6ab335a3f91a Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Mon, 17 Jul 2023 12:44:26 +0530 Subject: [PATCH 01/22] Added function for converting pag to mag Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/pag.py | 48 +++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index 4f8f2a28..5d8dcfbe 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -6,7 +6,7 @@ import networkx as nx import numpy as np -from pywhy_graphs import PAG, StationaryTimeSeriesPAG +from pywhy_graphs import CPDAG, PAG, StationaryTimeSeriesPAG from pywhy_graphs.algorithms.generic import single_source_shortest_mixed_path from pywhy_graphs.typing import Node, TsNode @@ -908,3 +908,49 @@ def _check_ts_node(node): ) if node[1] > 0: raise ValueError(f"All lag points should be 0, or less. You passed in {node}.") + + +def pag_to_mag(graph): + """Convert an PAG to an MAG. + + Parameters + ---------- + G : Graph + The PAG. + + Returns + ------- + mag : Graph + The MAG constructed from the PAG. + """ + copy_graph = graph.copy() + + cedges = set(copy_graph.circle_edges) + dedges = set(copy_graph.directed_edges) + + temp_cpdag = CPDAG() + + to_remove = [] + to_reorient = [] + to_add = [] + + for u, v in cedges: + if (v, u) in dedges: # remove the circle end from a 'o-->' edge to make a '-->' edge + to_remove.append((u, v)) + elif (v, u) not in cedges: # reorient a '--o' edge to '-->' + to_reorient.append((u, v)) + elif (v, u) in cedges and ( + v, + u, + ) not in to_add: # add all 'o--o' edges to the cpdag + to_add.append((u, v)) + for u, v in to_remove: + copy_graph.remove_edge(u, v, graph.circle_edge_name) + for u, v in to_reorient: + copy_graph.orient_uncertain_edge(u, v) + for u, v in to_add: + temp_cpdag.add_edge(v, u, temp_cpdag.undirected_edge_name) + + # flag = True + + return copy_graph From 418330e6054dbb74b4ba07d149bc558a675840c1 Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Sun, 27 Aug 2023 13:49:33 +0530 Subject: [PATCH 02/22] refactored code and added skeleton for meek rules Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/pag.py | 33 ++++++++++++++--------- pywhy_graphs/algorithms/tests/test_pag.py | 16 +++++++++++ 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index 5d8dcfbe..0da54330 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -22,6 +22,7 @@ "pds_t", "pds_t_path", "is_definite_noncollider", + "pag_to_mag", ] @@ -933,24 +934,32 @@ def pag_to_mag(graph): to_remove = [] to_reorient = [] to_add = [] - for u, v in cedges: if (v, u) in dedges: # remove the circle end from a 'o-->' edge to make a '-->' edge - to_remove.append((u, v)) + copy_graph.remove_edges(u,v) + #to_remove.append((u, v)) elif (v, u) not in cedges: # reorient a '--o' edge to '-->' - to_reorient.append((u, v)) + copy_graph.orient_uncertain_edge(u, v) elif (v, u) in cedges and ( v, u, ) not in to_add: # add all 'o--o' edges to the cpdag - to_add.append((u, v)) - for u, v in to_remove: - copy_graph.remove_edge(u, v, graph.circle_edge_name) - for u, v in to_reorient: - copy_graph.orient_uncertain_edge(u, v) - for u, v in to_add: - temp_cpdag.add_edge(v, u, temp_cpdag.undirected_edge_name) + CPDAG.add_edge(v, u, temp_cpdag.undirected_edge_name) + + flag = True + + # convert the graph into a DAG with no unshielded colliders + + while flag: + undedges = temp_cpdag.undirected_edges + if len(undedges) != 0: + for (u, v) in undedges: + temp_cpdag.remove_edge(u, v, temp_cpdag.undirected_edge_name) + temp_cpdag.add_edge(u, v, temp_cpdag.directed_edge_name) + #apply meek rules + break + else: + flag = False - # flag = True - return copy_graph + return None diff --git a/pywhy_graphs/algorithms/tests/test_pag.py b/pywhy_graphs/algorithms/tests/test_pag.py index c09607a8..352a51a1 100644 --- a/pywhy_graphs/algorithms/tests/test_pag.py +++ b/pywhy_graphs/algorithms/tests/test_pag.py @@ -647,3 +647,19 @@ def test_pdst(pdst_graph): ex_pdsep_t = pds_t(G, ("E", 0), ("x", -1)) assert ("y", -2) not in xe_pdsep_t assert ("y", -2) not in ex_pdsep_t + + +def test_pag_to_mag(): + G = PAG() + + # x o-o y o-o z, y o-o v + G.add_edge("y", "x", G.circle_edge_name) + G.add_edge("x", "y", G.circle_edge_name) + G.add_edge("z", "y", G.circle_edge_name) + G.add_edge("y", "z", G.circle_edge_name) + G.add_edge("v", "y", G.circle_edge_name) + G.add_edge("y", "v", G.circle_edge_name) + + new_g = pywhy_graphs.pag_to_mag(G) + print(new_g) + assert pywhy_graphs.valid_mag(new_g) From a0943aee45bfd9c3321601bd34c0356d2a954804 Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Tue, 29 Aug 2023 19:50:10 +0530 Subject: [PATCH 03/22] Added meek rules Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/pag.py | 203 ++++++++++++++++++++++++++++++++- 1 file changed, 198 insertions(+), 5 deletions(-) diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index 0da54330..717f1988 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -1,6 +1,6 @@ import logging from collections import deque -from itertools import chain +from itertools import chain, combinations, permutations from typing import List, Optional, Set, Tuple import networkx as nx @@ -911,6 +911,200 @@ def _check_ts_node(node): raise ValueError(f"All lag points should be 0, or less. You passed in {node}.") +def _apply_meek_rules(graph: CPDAG) -> None: + """Orient edges in a skeleton graph to estimate the causal DAG, or CPDAG. + These are known as the Meek rules :footcite:`Meek1995`. They are deterministic + in the sense that they are logical characterizations of what edges must be + present given the rest of the local graph structure. + Parameters + ---------- + graph : CPDAG + A graph containing directed and undirected edges. + """ + # For all the combination of nodes i and j, apply the following + # rules. + completed = False + while not completed: # type: ignore + change_flag = False + for i in graph.nodes: + for j in graph.neighbors(i): + if i == j: + continue + # Rule 1: Orient i-j into i->j whenever there is an arrow k->i + # such that k and j are nonadjacent. + r1_add = _meek_rule1(graph, i, j) + + # Rule 2: Orient i-j into i->j whenever there is a chain + # i->k->j. + r2_add = _meek_rule2(graph, i, j) + + # Rule 3: Orient i-j into i->j whenever there are two chains + # i-k->j and i-l->j such that k and l are nonadjacent. + r3_add = _meek_rule3(graph, i, j) + + # Rule 4: Orient i-j into i->j whenever there are two chains + # i-k->l and k->l->j such that k and j are nonadjacent. + # + r4_add = _meek_rule4(graph, i, j) + + if any([r1_add, r2_add, r3_add, r4_add]) and not change_flag: + change_flag = True + if not change_flag: + completed = True + break + + +def _meek_rule1(graph: CPDAG, i: str, j: str) -> bool: + """Apply rule 1 of Meek's rules. + Looks for i - j such that k -> i, such that (k,i,j) + is an unshielded triple. Then can orient i - j as i -> j. + """ + added_arrows = False + + # Check if i-j. + if graph.has_edge(i, j, graph.undirected_edge_name): + for k in graph.predecessors(i): + # Skip if k and j are adjacent because then it is a + # shielded triple + if j in graph.neighbors(k): + continue + + # check if the triple is in the graph's excluded triples + if frozenset((k, i, j)) in graph.excluded_triples: + continue + + # Make i-j into i->j + graph.orient_uncertain_edge(i, j) + + added_arrows = True + break + return added_arrows + + +def _meek_rule2(graph: CPDAG, i: str, j: str) -> bool: + """Apply rule 2 of Meek's rules. + Check for i - j, and then looks for i -> k -> j + triple, to orient i - j as i -> j. + """ + added_arrows = False + + # Check if i-j. + if graph.has_edge(i, j, graph.undirected_edge_name): + # Find nodes k where k is i->k + child_i = set() + for k in graph.successors(i): + if not graph.has_edge(k, i, graph.directed_edge_name): + child_i.add(k) + # Find nodes j where j is k->j. + parent_j = set() + for k in graph.predecessors(j): + if not graph.has_edge(j, k, graph.directed_edge_name): + parent_j.add(k) + + # Check if there is any node k where i->k->j. + candidate_k = child_i.intersection(parent_j) + # if the graph has excluded triples, we would check at this point + if graph.excluded_triples: + # check if the triple is in the graph's excluded triples + # if so, remove them from the candidates + for k in candidate_k: + if frozenset((i, k, j)) in graph.excluded_triples: + candidate_k.remove(k) + + # if there are candidate 'k' nodes, then orient the edge accordingly + if len(candidate_k) > 0: + # Make i-j into i->j + graph.orient_uncertain_edge(i, j) + added_arrows = True + return added_arrows + + +def _meek_rule3(graph: CPDAG, i: str, j: str) -> bool: + """Apply rule 3 of Meek's rules. + Check for i - j, and then looks for k -> j <- l + collider, and i - k and i - l, then orient i -> j. + """ + added_arrows = False + + # Check if i-j first + if graph.has_edge(i, j, graph.undirected_edge_name): + # For all the pairs of nodes adjacent to i, + # look for (k, l), such that j -> l and k -> l + for (k, l) in combinations(graph.neighbors(i), 2): + # Skip if k and l are adjacent. + if l in graph.neighbors(k): + continue + # Skip if not k->j. + if graph.has_edge(j, k, graph.directed_edge_name) or ( + not graph.has_edge(k, j, graph.directed_edge_name) + ): + continue + # Skip if not l->j. + if graph.has_edge(j, l, graph.directed_edge_name) or ( + not graph.has_edge(l, j, graph.directed_edge_name) + ): + continue + + # check if the triple is inside graph's excluded triples + if frozenset((l, i, k)) in graph.excluded_triples: + continue + + # if i - k and i - l, then at this point, we have a valid path + # to orient + if graph.has_edge(k, i, graph.undirected_edge_name) and graph.has_edge( + l, i, graph.undirected_edge_name + ): + graph.orient_uncertain_edge(i, j) + added_arrows = True + break + return added_arrows + + +def _meek_rule4(graph: CPDAG, i: str, j: str) -> bool: + """Apply rule 4 of Meek's rules. + Check for i - j, and then looks for i - k -> l -> j, to orient i - j as i -> j. + """ + added_arrows = False + + # Check if i-j. + if graph.has_edge(i, j, graph.undirected_edge_name): + # Find nodes k where k is i-k + adj_i = set() + for k in graph.neighbors(i): + if not graph.has_edge(k, i, graph.directed_edge_name): + adj_i.add(k) + + # Find nodes l where j is l->j. + parent_j = set() + for k in graph.predecessors(j): + if not graph.has_edge(j, k, graph.directed_edge_name): + parent_j.add(k) + + # generate all permutations of sets containing neighbors of i and parents of j + permut = permutations(adj_i, len(parent_j)) + unq = set() # type: ignore + for comb in permut: + zipped = zip(comb, parent_j) + unq.update(zipped) + + # check if these pairs have a directed edge between them and that k-j does not exist + dedges = set(graph.directed_edges) + undedges = set(graph.undirected_edges) + candidate_k = set() + for pair in unq: + if pair in dedges: + if (pair[0], j) not in undedges: + candidate_k.add(pair) + + # if there are candidate 'k->l' pairs, then orient the edge accordingly + if len(candidate_k) > 0: + # Make i-j into i->j + # logger.info(f"R2: Removing edge {i}-{j} to form {i}->{j}.") + graph.orient_uncertain_edge(i, j) + added_arrows = True + return added_arrows + + def pag_to_mag(graph): """Convert an PAG to an MAG. @@ -936,8 +1130,8 @@ def pag_to_mag(graph): to_add = [] for u, v in cedges: if (v, u) in dedges: # remove the circle end from a 'o-->' edge to make a '-->' edge - copy_graph.remove_edges(u,v) - #to_remove.append((u, v)) + copy_graph.remove_edges(u, v) + # to_remove.append((u, v)) elif (v, u) not in cedges: # reorient a '--o' edge to '-->' copy_graph.orient_uncertain_edge(u, v) elif (v, u) in cedges and ( @@ -956,10 +1150,9 @@ def pag_to_mag(graph): for (u, v) in undedges: temp_cpdag.remove_edge(u, v, temp_cpdag.undirected_edge_name) temp_cpdag.add_edge(u, v, temp_cpdag.directed_edge_name) - #apply meek rules + _apply_meek_rules(temp_cpdag) break else: flag = False - return None From 8f0daf68b02d88e71843efb583e4bc64732484db Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Sat, 2 Sep 2023 20:39:36 +0530 Subject: [PATCH 04/22] Completed pag_to_mag function Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/pag.py | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index 717f1988..083ec466 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -6,7 +6,7 @@ import networkx as nx import numpy as np -from pywhy_graphs import CPDAG, PAG, StationaryTimeSeriesPAG +from pywhy_graphs import ADMG, CPDAG, PAG, StationaryTimeSeriesPAG from pywhy_graphs.algorithms.generic import single_source_shortest_mixed_path from pywhy_graphs.typing import Node, TsNode @@ -1128,17 +1128,23 @@ def pag_to_mag(graph): to_remove = [] to_reorient = [] to_add = [] + for u, v in cedges: if (v, u) in dedges: # remove the circle end from a 'o-->' edge to make a '-->' edge - copy_graph.remove_edges(u, v) - # to_remove.append((u, v)) + to_remove.append((u, v)) elif (v, u) not in cedges: # reorient a '--o' edge to '-->' - copy_graph.orient_uncertain_edge(u, v) + to_reorient.append((u, v)) elif (v, u) in cedges and ( v, u, ) not in to_add: # add all 'o--o' edges to the cpdag - CPDAG.add_edge(v, u, temp_cpdag.undirected_edge_name) + to_add.append((u, v)) + for u, v in to_remove: + copy_graph.remove_edge(u, v, graph.circle_edge_name) + for u, v in to_reorient: + copy_graph.orient_uncertain_edge(u, v) + for u, v in to_add: + temp_cpdag.add_edge(v, u, temp_cpdag.undirected_edge_name) flag = True @@ -1155,4 +1161,20 @@ def pag_to_mag(graph): else: flag = False - return None + mag = ADMG() # provisional MAG + + # construct the final MAG + + for (u, v) in copy_graph.directed_edges: + mag.add_edge(u, v, mag.directed_edge_name) + + for (u, v) in copy_graph.undirected_edges: + mag.add_edge(u, v, mag.undirected_edge_name) + + for (u, v) in copy_graph.bidirected_edges: + mag.add_edge(u, v, mag.bidirected_edge_name) + + for (u, v) in temp_cpdag.directed_edges: + mag.add_edge(u, v, mag.directed_edge_name) + + return mag From 89ce696138aaf464daf18e84715c735a9cda1209 Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Wed, 4 Oct 2023 21:25:03 +0530 Subject: [PATCH 05/22] Added first test Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/tests/test_generic.py | 2 +- pywhy_graphs/algorithms/tests/test_pag.py | 44 +++++++++++++------ 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/pywhy_graphs/algorithms/tests/test_generic.py b/pywhy_graphs/algorithms/tests/test_generic.py index 0f717893..c4018181 100644 --- a/pywhy_graphs/algorithms/tests/test_generic.py +++ b/pywhy_graphs/algorithms/tests/test_generic.py @@ -2,7 +2,7 @@ import pytest import pywhy_graphs -from pywhy_graphs import ADMG +from pywhy_graphs import ADMG, PAG def test_convert_to_latent_confounder_errors(): diff --git a/pywhy_graphs/algorithms/tests/test_pag.py b/pywhy_graphs/algorithms/tests/test_pag.py index 352a51a1..9d1521c8 100644 --- a/pywhy_graphs/algorithms/tests/test_pag.py +++ b/pywhy_graphs/algorithms/tests/test_pag.py @@ -3,7 +3,7 @@ import pytest import pywhy_graphs -from pywhy_graphs import PAG +from pywhy_graphs import ADMG, PAG from pywhy_graphs.algorithms import ( discriminating_path, is_definite_noncollider, @@ -650,16 +650,32 @@ def test_pdst(pdst_graph): def test_pag_to_mag(): - G = PAG() - - # x o-o y o-o z, y o-o v - G.add_edge("y", "x", G.circle_edge_name) - G.add_edge("x", "y", G.circle_edge_name) - G.add_edge("z", "y", G.circle_edge_name) - G.add_edge("y", "z", G.circle_edge_name) - G.add_edge("v", "y", G.circle_edge_name) - G.add_edge("y", "v", G.circle_edge_name) - - new_g = pywhy_graphs.pag_to_mag(G) - print(new_g) - assert pywhy_graphs.valid_mag(new_g) + pag = PAG() + pag.add_edge("A", "D", pag.directed_edge_name) + pag.add_edge("A", "C", pag.circle_edge_name) + pag.add_edge("D", "A", pag.circle_edge_name) + pag.add_edge("B", "D", pag.directed_edge_name) + pag.add_edge("C", "D", pag.directed_edge_name) + pag.add_edge("D", "B", pag.circle_edge_name) + pag.add_edge("D", "C", pag.circle_edge_name) + pag.add_edge("C", "A", pag.circle_edge_name) + pag.add_edge("B", "A", pag.circle_edge_name) + pag.add_edge("A", "B", pag.circle_edge_name) + + out_mag = pywhy_graphs.pag_to_mag(pag) + + mag = ADMG() + mag.add_edge("A", "B", mag.directed_edge_name) + mag.add_edge("A", "C", mag.directed_edge_name) + mag.add_edge("A", "D", mag.directed_edge_name) + mag.add_edge("B", "D", mag.directed_edge_name) + mag.add_edge("C", "D", mag.directed_edge_name) + + out_edges = list(out_mag.edges()["directed"]) + assert ( + (("A", "B") in out_edges) + and (("A", "C") in out_edges) + and (("A", "D") in out_edges) + and (("B", "D") in out_edges) + and (("C", "D") in out_edges) + ) From ac8cb5e2be00bc31ad112fc4170cf0c755562f54 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 Sep 2023 09:13:51 -0600 Subject: [PATCH 06/22] Bump actions/checkout from 3 to 4 (#94) Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/main.yml | 8 ++++---- .github/workflows/pr_checks.yml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index c3afab7d..62236269 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -23,7 +23,7 @@ jobs: poetry-version: [1.5.1] steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup Python 3.9 uses: actions/setup-python@v4 with: @@ -65,7 +65,7 @@ jobs: shell: bash steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: @@ -123,7 +123,7 @@ jobs: shell: bash steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: @@ -168,7 +168,7 @@ jobs: if: startsWith(github.ref, 'refs/tags/') steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: diff --git a/.github/workflows/pr_checks.yml b/.github/workflows/pr_checks.yml index 5a8682c3..d2ef9655 100644 --- a/.github/workflows/pr_checks.yml +++ b/.github/workflows/pr_checks.yml @@ -22,7 +22,7 @@ jobs: run: | echo "PR_NUMBER=${{ github.event.pull_request.number }}" >> $GITHUB_ENV echo "TAGGED_MILESTONE=${{ github.event.pull_request.milestone.title }}" >> $GITHUB_ENV - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: '0' - name: Check that CHANGELOG has been updated From f5398d914b29b57449ed7aae5bde0d33d04df546 Mon Sep 17 00:00:00 2001 From: Aryan Roy <50577809+aryan26roy@users.noreply.github.com> Date: Tue, 26 Sep 2023 21:46:46 +0530 Subject: [PATCH 07/22] [ENH] Add the ability to convert a DAG to an MAG (#96) * Add is_maximal function * add DAG to MAG function --------- Signed-off-by: Aryan Roy Co-authored-by: Adam Li --- doc/api.rst | 1 + doc/whats_new/v0.2.rst | 1 + pywhy_graphs/algorithms/generic.py | 120 +++++++++++++++++ pywhy_graphs/algorithms/tests/test_generic.py | 124 ++++++++++++++++++ 4 files changed, 246 insertions(+) diff --git a/doc/api.rst b/doc/api.rst index dfa8b06d..d5974823 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -43,6 +43,7 @@ causal graph operations. .. autosummary:: :toctree: generated/ + dag_to_mag valid_mag has_adc inducing_path diff --git a/doc/whats_new/v0.2.rst b/doc/whats_new/v0.2.rst index 2bfec356..bc7bcf9b 100644 --- a/doc/whats_new/v0.2.rst +++ b/doc/whats_new/v0.2.rst @@ -26,6 +26,7 @@ Version 0.2 Changelog --------- - |Feature| Implement and test functions to validate an MAG and check the presence of almost directed cycles, by `Aryan Roy`_ (:pr:`91`) +- |Feature| Implement and test functions to convert a DAG to MAG, by `Aryan Roy`_ (:pr:`96`) Code and Documentation Contributors ----------------------------------- diff --git a/pywhy_graphs/algorithms/generic.py b/pywhy_graphs/algorithms/generic.py index 27d41211..cdee0be1 100644 --- a/pywhy_graphs/algorithms/generic.py +++ b/pywhy_graphs/algorithms/generic.py @@ -15,6 +15,8 @@ "inducing_path", "has_adc", "valid_mag", + "dag_to_mag", + "is_maximal", ] @@ -567,6 +569,9 @@ def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None): if node_x == node_y: raise ValueError("The source and destination nodes are the same.") + if (node_x in L) or (node_y in L) or (node_x in S) or (node_y in S): + return (False, []) + edges = G.edges() # XXX: fix this when graphs are refactored to only check for directed/bidirected edge types @@ -703,3 +708,118 @@ def valid_mag(G: ADMG, L: set = None, S: set = None): return False return True + + +def dag_to_mag(G, L: Set = None, S: Set = None): + """Converts a DAG to a valid MAG. + + The algorithm is defined in :footcite:`Zhang2008` on page 1877. + + Parameters: + ----------- + G : Graph + The graph. + L : Set + Nodes that are ignored on the path. Defaults to an empty set. + S : Set + Nodes that are always conditioned on. Defaults to an empty set. + + Returns + ------- + mag : Graph + The MAG. + """ + + if L is None: + L = set() + + if S is None: + S = set() + + # for each pair of nodes find if they have an inducing path between them. + # only then will they be adjacent in the MAG. + + all_nodes = set(G.nodes) + adj_nodes = [] + + for source in all_nodes: + copy_all = all_nodes.copy() + copy_all.remove(source) + for dest in copy_all: + out = inducing_path(G, source, dest, L, S) + if out[0] is True and {source, dest} not in adj_nodes: + adj_nodes.append({source, dest}) + + # find the ancestors of B U S (ansB) and A U S (ansA) for each pair of adjacent nodes + + mag = ADMG() + + for A, B in adj_nodes: + + AuS = S.union(A) + BuS = S.union(B) + + ansA: Set = set() + ansB: Set = set() + + for node in AuS: + ansA = ansA.union(_directed_sub_graph_ancestors(G, node)) + + for node in BuS: + ansB = ansB.union(_directed_sub_graph_ancestors(G, node)) + + if A in ansB and B not in ansA: + # if A is in ansB and B is not in ansA, A -> B + mag.add_edge(A, B, mag.directed_edge_name) + + elif A not in ansB and B in ansA: + # if B is in ansA and A is not in ansB, A <- B + mag.add_edge(B, A, mag.directed_edge_name) + + elif A not in ansB and B not in ansA: + # if A is not in ansB and B is not in ansA, A <-> B + mag.add_edge(B, A, mag.bidirected_edge_name) + + elif A in ansB and B in ansA: + # if A is in ansB and B is in ansA, A - B + mag.add_edge(B, A, mag.undirected_edge_name) + + return mag + + +def is_maximal(G, L: Set = None, S: Set = None): + """Checks to see if the graph is maximal. + + Parameters: + ----------- + G : Graph + The graph. + + Returns + ------- + is_maximal : bool + A boolean indicating whether the provided graph is maximal or not. + """ + + if L is None: + L = set() + + if S is None: + S = set() + + all_nodes = set(G.nodes) + checked = set() + for source in all_nodes: + nb = set(G.neighbors(source)) + cur_set = all_nodes - nb + cur_set.remove(source) + for dest in cur_set: + current_pair = frozenset({source, dest}) + if current_pair not in checked: + checked.add(current_pair) + out = inducing_path(G, source, dest, L, S) + if out[0] is True: + return False + else: + continue + return True diff --git a/pywhy_graphs/algorithms/tests/test_generic.py b/pywhy_graphs/algorithms/tests/test_generic.py index c4018181..670d56b0 100644 --- a/pywhy_graphs/algorithms/tests/test_generic.py +++ b/pywhy_graphs/algorithms/tests/test_generic.py @@ -214,6 +214,30 @@ def test_inducing_path_corner_cases(): assert pywhy_graphs.inducing_path(admg, "X", "Y", L, S)[0] + # X -> Z <- Y, A <- B <- Z + admg = ADMG() + admg.add_edge("X", "Z", admg.directed_edge_name) + admg.add_edge("Y", "Z", admg.directed_edge_name) + admg.add_edge("Z", "B", admg.directed_edge_name) + admg.add_edge("B", "A", admg.directed_edge_name) + + L = {"X"} + S = {"A"} + + assert not pywhy_graphs.inducing_path(admg, "X", "Y", L, S)[0] + + # X -> Z <- Y, A <- B <- Z + admg = ADMG() + admg.add_edge("X", "Z", admg.directed_edge_name) + admg.add_edge("Y", "Z", admg.directed_edge_name) + admg.add_edge("Z", "B", admg.directed_edge_name) + admg.add_edge("B", "A", admg.directed_edge_name) + + L = {} + S = {"A", "Y"} + + assert not pywhy_graphs.inducing_path(admg, "X", "Y", L, S)[0] + def test_is_collider(): # Z -> X -> A <- B -> Y; H -> A @@ -348,3 +372,103 @@ def test_valid_mag(): admg.add_edge("H", "J", admg.undirected_edge_name) assert not pywhy_graphs.valid_mag(admg) # there is an undirected edge between H and J + + +def test_dag_to_mag(): + + # A -> E -> S + # H -> E , H -> R + admg = ADMG() + admg.add_edge("A", "E", admg.directed_edge_name) + admg.add_edge("E", "S", admg.directed_edge_name) + admg.add_edge("H", "E", admg.directed_edge_name) + admg.add_edge("H", "R", admg.directed_edge_name) + + S = {"S"} + L = {"H"} + + out_mag = pywhy_graphs.dag_to_mag(admg, L, S) + assert pywhy_graphs.is_maximal(out_mag) + assert not pywhy_graphs.has_adc(out_mag) + out_edges = out_mag.edges() + dir_edges = list(out_edges["directed"]) + assert ( + ("A", "R") in out_edges["directed"] + and ("E", "R") in out_edges["directed"] + and len(out_edges["directed"]) == 2 + ) + assert ("A", "E") in out_edges["undirected"] + + out_mag = pywhy_graphs.dag_to_mag(admg) + dir_edges = list(out_mag.edges()["directed"]) + + assert ( + ("A", "E") in dir_edges + and ("E", "S") in dir_edges + and ("H", "E") in dir_edges + and ("H", "R") in dir_edges + ) + + # A -> E -> S <- H + # H -> E , H -> R, + + admg = ADMG() + admg.add_edge("A", "E", admg.directed_edge_name) + admg.add_edge("H", "S", admg.directed_edge_name) + admg.add_edge("H", "E", admg.directed_edge_name) + admg.add_edge("H", "R", admg.directed_edge_name) + + S = {"S"} + L = {"H"} + + out_mag = pywhy_graphs.dag_to_mag(admg, L, S) + assert pywhy_graphs.is_maximal(out_mag) + assert not pywhy_graphs.has_adc(out_mag) + out_edges = out_mag.edges() + + dir_edges = list(out_edges["directed"]) + assert ("A", "E") in out_edges["directed"] and len(out_edges["directed"]) == 1 + assert ("E", "R") in out_edges["bidirected"] + + # P -> S -> L <- G + # G -> S -> I <- J + # J -> S + + admg = ADMG() + admg.add_edge("P", "S", admg.directed_edge_name) + admg.add_edge("S", "L", admg.directed_edge_name) + admg.add_edge("G", "S", admg.directed_edge_name) + admg.add_edge("G", "L", admg.directed_edge_name) + admg.add_edge("I", "S", admg.directed_edge_name) + admg.add_edge("J", "I", admg.directed_edge_name) + admg.add_edge("J", "S", admg.directed_edge_name) + + S = set() + L = {"J"} + + out_mag = pywhy_graphs.dag_to_mag(admg, L, S) + assert pywhy_graphs.is_maximal(out_mag) + assert not pywhy_graphs.has_adc(out_mag) + out_edges = out_mag.edges() + dir_edges = list(out_edges["directed"]) + assert ( + ("G", "S") in dir_edges + and ("G", "L") in dir_edges + and ("S", "L") in dir_edges + and ("I", "S") in dir_edges + and ("P", "S") in dir_edges + and len(dir_edges) == 5 + ) + + +def test_is_maximal(): + # X <- Y <-> Z <-> H; Z -> X + admg = ADMG() + admg.add_edge("Y", "X", admg.directed_edge_name) + admg.add_edge("Z", "X", admg.directed_edge_name) + admg.add_edge("Z", "Y", admg.bidirected_edge_name) + admg.add_edge("Z", "H", admg.bidirected_edge_name) + + S = {} + L = {"Y"} + assert not pywhy_graphs.is_maximal(admg, L, S) From 183f0bff60881ab033fd8d9debfc7bc45ccd5f80 Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Wed, 4 Oct 2023 21:26:54 +0530 Subject: [PATCH 08/22] Linting Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/tests/test_generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pywhy_graphs/algorithms/tests/test_generic.py b/pywhy_graphs/algorithms/tests/test_generic.py index 670d56b0..a9876b57 100644 --- a/pywhy_graphs/algorithms/tests/test_generic.py +++ b/pywhy_graphs/algorithms/tests/test_generic.py @@ -2,7 +2,7 @@ import pytest import pywhy_graphs -from pywhy_graphs import ADMG, PAG +from pywhy_graphs import ADMG def test_convert_to_latent_confounder_errors(): From a76bddca8d7be2fa76ca750d43e38902b57e75ab Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Thu, 5 Oct 2023 21:18:12 +0530 Subject: [PATCH 09/22] Added information to docstring Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/pag.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index 083ec466..203d7b7b 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -1106,7 +1106,10 @@ def _meek_rule4(graph: CPDAG, i: str, j: str) -> bool: def pag_to_mag(graph): - """Convert an PAG to an MAG. + """Convert a PAG to an MAG using algorithm defined in Theorem 2 of + defined in :footcite:`Zhang2008`. + The algorithm turns all o-> edges to -> and -o edges to ->. Then converts the + input graph into a DAG with no unshielded colliders. Parameters ---------- From 043b872badfdb8a4f27972e7ba27af6dc049279f Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Sat, 7 Oct 2023 18:32:03 +0530 Subject: [PATCH 10/22] Added checks for markov equivalent graphs in tests Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/pag.py | 1 + pywhy_graphs/algorithms/tests/test_pag.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index 203d7b7b..3af6362c 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -1167,6 +1167,7 @@ def pag_to_mag(graph): mag = ADMG() # provisional MAG # construct the final MAG + print(temp_cpdag.edges()) for (u, v) in copy_graph.directed_edges: mag.add_edge(u, v, mag.directed_edge_name) diff --git a/pywhy_graphs/algorithms/tests/test_pag.py b/pywhy_graphs/algorithms/tests/test_pag.py index 9d1521c8..b0e8d5aa 100644 --- a/pywhy_graphs/algorithms/tests/test_pag.py +++ b/pywhy_graphs/algorithms/tests/test_pag.py @@ -673,8 +673,8 @@ def test_pag_to_mag(): out_edges = list(out_mag.edges()["directed"]) assert ( - (("A", "B") in out_edges) - and (("A", "C") in out_edges) + ((("A", "B") in out_edges) or (("B", "A") in out_edges)) + and ((("A", "C") in out_edges) or (("C", "A") in out_edges)) and (("A", "D") in out_edges) and (("B", "D") in out_edges) and (("C", "D") in out_edges) From 0766d6b3ef4a2eca1419591d26ed847452f82ddf Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Sat, 7 Oct 2023 19:56:49 +0530 Subject: [PATCH 11/22] Added another test Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/tests/test_pag.py | 25 ++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/pywhy_graphs/algorithms/tests/test_pag.py b/pywhy_graphs/algorithms/tests/test_pag.py index b0e8d5aa..cad6ac72 100644 --- a/pywhy_graphs/algorithms/tests/test_pag.py +++ b/pywhy_graphs/algorithms/tests/test_pag.py @@ -664,13 +664,6 @@ def test_pag_to_mag(): out_mag = pywhy_graphs.pag_to_mag(pag) - mag = ADMG() - mag.add_edge("A", "B", mag.directed_edge_name) - mag.add_edge("A", "C", mag.directed_edge_name) - mag.add_edge("A", "D", mag.directed_edge_name) - mag.add_edge("B", "D", mag.directed_edge_name) - mag.add_edge("C", "D", mag.directed_edge_name) - out_edges = list(out_mag.edges()["directed"]) assert ( ((("A", "B") in out_edges) or (("B", "A") in out_edges)) @@ -679,3 +672,21 @@ def test_pag_to_mag(): and (("B", "D") in out_edges) and (("C", "D") in out_edges) ) + + pag = PAG() + pag.add_edge("A", "B", pag.circle_edge_name) + pag.add_edge("B", "A", pag.directed_edge_name) + pag.add_edge("D", "A", pag.directed_edge_name) + pag.add_edge("A", "D", pag.circle_edge_name) + pag.add_edge("D", "B", pag.circle_edge_name) + pag.add_edge("B", "D", pag.circle_edge_name) + + out_mag = pywhy_graphs.pag_to_mag(pag) + + mag = ADMG() + mag.add_edge("B", "A", mag.directed_edge_name) + mag.add_edge("D", "A", mag.directed_edge_name) + mag.add_edge("D", "B", mag.directed_edge_name) + + out_edges = list(out_mag.edges()["directed"]) + assert (("B", "A") in out_edges) and (("D", "A") in out_edges) and (("D", "B") in out_edges) From 4fdaa89352ba61730e2a8efb717b120926999089 Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Sat, 7 Oct 2023 20:12:08 +0530 Subject: [PATCH 12/22] Added CHANGELOG Signed-off-by: Aryan Roy --- doc/whats_new/v0.2.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/whats_new/v0.2.rst b/doc/whats_new/v0.2.rst index bc7bcf9b..d08bc8a1 100644 --- a/doc/whats_new/v0.2.rst +++ b/doc/whats_new/v0.2.rst @@ -27,6 +27,7 @@ Changelog --------- - |Feature| Implement and test functions to validate an MAG and check the presence of almost directed cycles, by `Aryan Roy`_ (:pr:`91`) - |Feature| Implement and test functions to convert a DAG to MAG, by `Aryan Roy`_ (:pr:`96`) +- |Feature| Implement and test functions to convert a PAG to MAG, by `Aryan Roy`_ (:pr:`93`) Code and Documentation Contributors ----------------------------------- From 9b34e6f3a5e11085c1b4de82b1686922cc9c1104 Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Tue, 10 Oct 2023 19:42:10 +0530 Subject: [PATCH 13/22] Fixed test Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/tests/test_pag.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pywhy_graphs/algorithms/tests/test_pag.py b/pywhy_graphs/algorithms/tests/test_pag.py index cad6ac72..5642c5e8 100644 --- a/pywhy_graphs/algorithms/tests/test_pag.py +++ b/pywhy_graphs/algorithms/tests/test_pag.py @@ -689,4 +689,8 @@ def test_pag_to_mag(): mag.add_edge("D", "B", mag.directed_edge_name) out_edges = list(out_mag.edges()["directed"]) - assert (("B", "A") in out_edges) and (("D", "A") in out_edges) and (("D", "B") in out_edges) + assert ( + (("B", "A") in out_edges) + and (("D", "A") in out_edges) + and ((("D", "B") in out_edges) or (("B", "D") in out_edges)) + ) From 0825a8f6fd48859b2001f3ce562ba38c1f5bd8f8 Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Tue, 10 Oct 2023 19:55:27 +0530 Subject: [PATCH 14/22] Using has_edge API Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/tests/test_pag.py | 25 ++++++++--------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/pywhy_graphs/algorithms/tests/test_pag.py b/pywhy_graphs/algorithms/tests/test_pag.py index 5642c5e8..d7815bcf 100644 --- a/pywhy_graphs/algorithms/tests/test_pag.py +++ b/pywhy_graphs/algorithms/tests/test_pag.py @@ -3,7 +3,7 @@ import pytest import pywhy_graphs -from pywhy_graphs import ADMG, PAG +from pywhy_graphs import PAG from pywhy_graphs.algorithms import ( discriminating_path, is_definite_noncollider, @@ -664,13 +664,12 @@ def test_pag_to_mag(): out_mag = pywhy_graphs.pag_to_mag(pag) - out_edges = list(out_mag.edges()["directed"]) assert ( - ((("A", "B") in out_edges) or (("B", "A") in out_edges)) - and ((("A", "C") in out_edges) or (("C", "A") in out_edges)) - and (("A", "D") in out_edges) - and (("B", "D") in out_edges) - and (("C", "D") in out_edges) + ((out_mag.has_edge("A", "B")) or (out_mag.has_edge("B", "A"))) + and ((out_mag.has_edge("A", "C")) or (out_mag.has_edge("C", "A"))) + and (out_mag.has_edge("A", "D")) + and (out_mag.has_edge("B", "D")) + and (out_mag.has_edge("C", "D")) ) pag = PAG() @@ -683,14 +682,8 @@ def test_pag_to_mag(): out_mag = pywhy_graphs.pag_to_mag(pag) - mag = ADMG() - mag.add_edge("B", "A", mag.directed_edge_name) - mag.add_edge("D", "A", mag.directed_edge_name) - mag.add_edge("D", "B", mag.directed_edge_name) - - out_edges = list(out_mag.edges()["directed"]) assert ( - (("B", "A") in out_edges) - and (("D", "A") in out_edges) - and ((("D", "B") in out_edges) or (("B", "D") in out_edges)) + out_mag.has_edge("B", "A") + and out_mag.has_edge("D", "A") + and (out_mag.has_edge("D", "B") or out_mag.has_edge("B", "D")) ) From 41fc2ff838186c5f7d9eff19c6631fa075975e33 Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Wed, 11 Oct 2023 18:36:01 +0530 Subject: [PATCH 15/22] Fixed a bug in implementation Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/pag.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index 3af6362c..8395acd5 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -1143,7 +1143,7 @@ def pag_to_mag(graph): ) not in to_add: # add all 'o--o' edges to the cpdag to_add.append((u, v)) for u, v in to_remove: - copy_graph.remove_edge(u, v, graph.circle_edge_name) + copy_graph.remove_edge(u, v, copy_graph.circle_edge_name) for u, v in to_reorient: copy_graph.orient_uncertain_edge(u, v) for u, v in to_add: @@ -1172,13 +1172,11 @@ def pag_to_mag(graph): for (u, v) in copy_graph.directed_edges: mag.add_edge(u, v, mag.directed_edge_name) - for (u, v) in copy_graph.undirected_edges: - mag.add_edge(u, v, mag.undirected_edge_name) - - for (u, v) in copy_graph.bidirected_edges: - mag.add_edge(u, v, mag.bidirected_edge_name) for (u, v) in temp_cpdag.directed_edges: mag.add_edge(u, v, mag.directed_edge_name) + + for (u, v) in temp_cpdag.undirected_edges: + mag.add_edge(u, v, mag.directed_edge_name) return mag From 7bec4b8a6abd206860d4edbc9b1467ac94363da6 Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Wed, 18 Oct 2023 22:25:44 +0530 Subject: [PATCH 16/22] Added another test Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/pag.py | 1 - pywhy_graphs/algorithms/tests/test_pag.py | 19 +++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index 8395acd5..346bff66 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -1167,7 +1167,6 @@ def pag_to_mag(graph): mag = ADMG() # provisional MAG # construct the final MAG - print(temp_cpdag.edges()) for (u, v) in copy_graph.directed_edges: mag.add_edge(u, v, mag.directed_edge_name) diff --git a/pywhy_graphs/algorithms/tests/test_pag.py b/pywhy_graphs/algorithms/tests/test_pag.py index d7815bcf..0bbf661f 100644 --- a/pywhy_graphs/algorithms/tests/test_pag.py +++ b/pywhy_graphs/algorithms/tests/test_pag.py @@ -687,3 +687,22 @@ def test_pag_to_mag(): and out_mag.has_edge("D", "A") and (out_mag.has_edge("D", "B") or out_mag.has_edge("B", "D")) ) + + pag = PAG() + pag.add_edge("A", "B", pag.directed_edge_name) + pag.add_edge("C", "B", pag.directed_edge_name) + pag.add_edge("E", "B", pag.directed_edge_name) + pag.add_edge("E", "D", pag.circle_edge_name) + pag.add_edge("C", "D", pag.circle_edge_name) + pag.add_edge("D", "E", pag.circle_edge_name) + pag.add_edge("D", "C", pag.circle_edge_name) + + out_mag = pywhy_graphs.pag_to_mag(pag) + + assert ( + out_mag.has_edge("A", "B") + and out_mag.has_edge("C", "B") + and out_mag.has_edge("E", "B") + and (out_mag.has_edge("E", "D") or out_mag.has_edge("D", "E")) + and (out_mag.has_edge("D","C") or out_mag.has_edge("C","D")) + ) \ No newline at end of file From 0a7c1f9de251498213cadfe6246c0847f82d5953 Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Wed, 18 Oct 2023 22:27:57 +0530 Subject: [PATCH 17/22] Linting Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/pag.py | 3 +-- pywhy_graphs/algorithms/tests/test_pag.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index 346bff66..faec01c2 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -1171,10 +1171,9 @@ def pag_to_mag(graph): for (u, v) in copy_graph.directed_edges: mag.add_edge(u, v, mag.directed_edge_name) - for (u, v) in temp_cpdag.directed_edges: mag.add_edge(u, v, mag.directed_edge_name) - + for (u, v) in temp_cpdag.undirected_edges: mag.add_edge(u, v, mag.directed_edge_name) diff --git a/pywhy_graphs/algorithms/tests/test_pag.py b/pywhy_graphs/algorithms/tests/test_pag.py index 0bbf661f..0d522a65 100644 --- a/pywhy_graphs/algorithms/tests/test_pag.py +++ b/pywhy_graphs/algorithms/tests/test_pag.py @@ -704,5 +704,5 @@ def test_pag_to_mag(): and out_mag.has_edge("C", "B") and out_mag.has_edge("E", "B") and (out_mag.has_edge("E", "D") or out_mag.has_edge("D", "E")) - and (out_mag.has_edge("D","C") or out_mag.has_edge("C","D")) - ) \ No newline at end of file + and (out_mag.has_edge("D", "C") or out_mag.has_edge("C", "D")) + ) From 746298157b265369dfa776bc985bbe44108ed525 Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Mon, 23 Oct 2023 11:05:31 +0530 Subject: [PATCH 18/22] Fixed a bug Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/pag.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index faec01c2..fb9f2cd3 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -1174,7 +1174,4 @@ def pag_to_mag(graph): for (u, v) in temp_cpdag.directed_edges: mag.add_edge(u, v, mag.directed_edge_name) - for (u, v) in temp_cpdag.undirected_edges: - mag.add_edge(u, v, mag.directed_edge_name) - return mag From 1bc397ea10686478a78a98b00ae232740a46da54 Mon Sep 17 00:00:00 2001 From: Aryan Roy <50577809+aryan26roy@users.noreply.github.com> Date: Mon, 30 Oct 2023 22:01:06 +0530 Subject: [PATCH 19/22] Update pywhy_graphs/algorithms/pag.py Co-authored-by: Adam Li Signed-off-by: Aryan Roy <50577809+aryan26roy@users.noreply.github.com> --- pywhy_graphs/algorithms/pag.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index fb9f2cd3..547ef693 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -1106,7 +1106,9 @@ def _meek_rule4(graph: CPDAG, i: str, j: str) -> bool: def pag_to_mag(graph): - """Convert a PAG to an MAG using algorithm defined in Theorem 2 of + """Sample a MAG from a PAG using Zhang's algorithm. + + Using algorithm defined in Theorem 2 of defined in :footcite:`Zhang2008`. The algorithm turns all o-> edges to -> and -o edges to ->. Then converts the input graph into a DAG with no unshielded colliders. From 77fa079e914c55d17b1db4431b0c265aac1aac2f Mon Sep 17 00:00:00 2001 From: Aryan Roy <50577809+aryan26roy@users.noreply.github.com> Date: Mon, 30 Oct 2023 22:01:38 +0530 Subject: [PATCH 20/22] Update pywhy_graphs/algorithms/pag.py Co-authored-by: Adam Li Signed-off-by: Aryan Roy <50577809+aryan26roy@users.noreply.github.com> --- pywhy_graphs/algorithms/pag.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index 547ef693..b4ecfce2 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -1109,8 +1109,7 @@ def pag_to_mag(graph): """Sample a MAG from a PAG using Zhang's algorithm. Using algorithm defined in Theorem 2 of - defined in :footcite:`Zhang2008`. - The algorithm turns all o-> edges to -> and -o edges to ->. Then converts the + :footcite:`Zhang2008`, turns all o-> edges to -> and -o edges to ->. Then converts the input graph into a DAG with no unshielded colliders. Parameters From d335367f62a8c62763105d4738bb41a62e688414 Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Mon, 30 Oct 2023 22:40:29 +0530 Subject: [PATCH 21/22] Added comments to unit tests Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/pag.py | 4 ++-- pywhy_graphs/algorithms/tests/test_pag.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index b4ecfce2..65bdbba5 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -1107,8 +1107,8 @@ def _meek_rule4(graph: CPDAG, i: str, j: str) -> bool: def pag_to_mag(graph): """Sample a MAG from a PAG using Zhang's algorithm. - - Using algorithm defined in Theorem 2 of + + Using algorithm defined in Theorem 2 of :footcite:`Zhang2008`, turns all o-> edges to -> and -o edges to ->. Then converts the input graph into a DAG with no unshielded colliders. diff --git a/pywhy_graphs/algorithms/tests/test_pag.py b/pywhy_graphs/algorithms/tests/test_pag.py index 0d522a65..1f0deb53 100644 --- a/pywhy_graphs/algorithms/tests/test_pag.py +++ b/pywhy_graphs/algorithms/tests/test_pag.py @@ -650,6 +650,10 @@ def test_pdst(pdst_graph): def test_pag_to_mag(): + + # C o- A o-> D <-o B + # B o-o A o-o C o-> D + pag = PAG() pag.add_edge("A", "D", pag.directed_edge_name) pag.add_edge("A", "C", pag.circle_edge_name) @@ -664,6 +668,9 @@ def test_pag_to_mag(): out_mag = pywhy_graphs.pag_to_mag(pag) + # C <- A -> B -> D or C -> A -> B -> D or C <- A <- B -> D + # A -> D <- C + assert ( ((out_mag.has_edge("A", "B")) or (out_mag.has_edge("B", "A"))) and ((out_mag.has_edge("A", "C")) or (out_mag.has_edge("C", "A"))) @@ -672,6 +679,8 @@ def test_pag_to_mag(): and (out_mag.has_edge("C", "D")) ) + # D o-> A <-o B + # D o-o B pag = PAG() pag.add_edge("A", "B", pag.circle_edge_name) pag.add_edge("B", "A", pag.directed_edge_name) @@ -682,12 +691,18 @@ def test_pag_to_mag(): out_mag = pywhy_graphs.pag_to_mag(pag) + # B -> A <- D + # D -> B or D <- B + assert ( out_mag.has_edge("B", "A") and out_mag.has_edge("D", "A") and (out_mag.has_edge("D", "B") or out_mag.has_edge("B", "D")) ) + # A -> B <- C o-o D + # D o-o E -> B + pag = PAG() pag.add_edge("A", "B", pag.directed_edge_name) pag.add_edge("C", "B", pag.directed_edge_name) @@ -699,6 +714,9 @@ def test_pag_to_mag(): out_mag = pywhy_graphs.pag_to_mag(pag) + # A -> B <- C <- D or A -> B <- C -> D + # D <- E -> B or D <- E -> B + assert ( out_mag.has_edge("A", "B") and out_mag.has_edge("C", "B") From 194f66369d0d092cf42cff336cd5562967a59bde Mon Sep 17 00:00:00 2001 From: Aryan Roy Date: Mon, 30 Oct 2023 22:42:48 +0530 Subject: [PATCH 22/22] Cleaned up the docstring Signed-off-by: Aryan Roy --- pywhy_graphs/algorithms/pag.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index 65bdbba5..d6c69e45 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -1108,9 +1108,9 @@ def _meek_rule4(graph: CPDAG, i: str, j: str) -> bool: def pag_to_mag(graph): """Sample a MAG from a PAG using Zhang's algorithm. - Using algorithm defined in Theorem 2 of - :footcite:`Zhang2008`, turns all o-> edges to -> and -o edges to ->. Then converts the - input graph into a DAG with no unshielded colliders. + Using the algorithm defined in Theorem 2 of :footcite:`Zhang2008`, which turns all + o-> edges to -> and -o edges to ->, then it converts the graph into a DAG with + no unshielded colliders using the meek rules. Parameters ----------