Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add the ability to convert a PAG to MAG #93

Merged
merged 24 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/whats_new/v0.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Changelog
---------
- |Feature| Implement and test functions to validate an MAG and check the presence of almost directed cycles, by `Aryan Roy`_ (:pr:`91`)
- |Feature| Implement and test functions to convert a DAG to MAG, by `Aryan Roy`_ (:pr:`96`)
- |Feature| Implement and test functions to convert a PAG to MAG, by `Aryan Roy`_ (:pr:`93`)

Code and Documentation Contributors
-----------------------------------
Expand Down
272 changes: 270 additions & 2 deletions pywhy_graphs/algorithms/pag.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import logging
from collections import deque
from itertools import chain
from itertools import chain, combinations, permutations
from typing import List, Optional, Set, Tuple

import networkx as nx
import numpy as np

from pywhy_graphs import PAG, StationaryTimeSeriesPAG
from pywhy_graphs import ADMG, CPDAG, PAG, StationaryTimeSeriesPAG
from pywhy_graphs.algorithms.generic import single_source_shortest_mixed_path
from pywhy_graphs.typing import Node, TsNode

Expand All @@ -22,6 +22,7 @@
"pds_t",
"pds_t_path",
"is_definite_noncollider",
"pag_to_mag",
]


Expand Down Expand Up @@ -908,3 +909,270 @@
)
if node[1] > 0:
raise ValueError(f"All lag points should be 0, or less. You passed in {node}.")


def _apply_meek_rules(graph: CPDAG) -> None:
    """Orient undirected edges of a CPDAG in place using the Meek rules.

    These are known as the Meek rules :footcite:`Meek1995`. They are
    deterministic in the sense that they are logical characterizations of
    what edges must be present given the rest of the local graph structure.

    The four rules are tried on every ordered pair ``(i, j)`` where ``j`` is
    a neighbor of ``i``. The sweep over the whole graph is repeated until a
    complete pass orients no further edge.

    Parameters
    ----------
    graph : CPDAG
        A graph containing directed and undirected edges. It is modified
        in place.
    """
    changed = True
    while changed:
        changed = False

        # Snapshot the nodes and neighbors: the rules below re-orient edges
        # of ``graph`` while we are still walking over it.
        for i in list(graph.nodes):
            for j in list(graph.neighbors(i)):
                if i == j:
                    continue

                # Rule 1: Orient i-j into i->j whenever there is an arrow k->i
                # such that k and j are nonadjacent.
                r1_add = _meek_rule1(graph, i, j)

                # Rule 2: Orient i-j into i->j whenever there is a chain
                # i->k->j.
                r2_add = _meek_rule2(graph, i, j)

                # Rule 3: Orient i-j into i->j whenever there are two chains
                # i-k->j and i-l->j such that k and l are nonadjacent.
                r3_add = _meek_rule3(graph, i, j)

                # Rule 4: Orient i-j into i->j whenever there are two chains
                # i-k->l and k->l->j such that k and j are nonadjacent.
                r4_add = _meek_rule4(graph, i, j)

                # Every rule is always invoked (no short-circuit); we only
                # need to remember whether any of them changed the graph.
                if r1_add or r2_add or r3_add or r4_add:
                    changed = True


def _meek_rule1(graph: CPDAG, i: str, j: str) -> bool:
    """Apply rule 1 of Meek's rules.

    Looks for an undirected edge ``i - j`` and a predecessor ``k`` of ``i``
    (``k -> i``) such that ``(k, i, j)`` is an unshielded triple. If such a
    ``k`` exists, ``i - j`` is oriented as ``i -> j``.

    Parameters
    ----------
    graph : CPDAG
        The graph to orient. It is modified in place when the rule fires.
    i : node
        The tail candidate of the edge.
    j : node
        The head candidate of the edge.

    Returns
    -------
    added_arrows : bool
        True if the edge was oriented, False otherwise.
    """
    # The rule only applies to an undirected edge i - j.
    if not graph.has_edge(i, j, graph.undirected_edge_name):
        return False

    for k in graph.predecessors(i):
        # Skip if k and j are adjacent because then it is a shielded triple.
        if j in graph.neighbors(k):
            continue

        # Skip triples that the graph explicitly marks as excluded.
        if frozenset((k, i, j)) in graph.excluded_triples:
            continue

        # Make i-j into i->j. One witness k is enough, so stop immediately
        # (this also avoids iterating predecessors after mutating the graph).
        graph.orient_uncertain_edge(i, j)
        return True

    return False


def _meek_rule2(graph: CPDAG, i: str, j: str) -> bool:
    """Apply rule 2 of Meek's rules.

    Checks for an undirected edge ``i - j`` and then looks for a chain
    ``i -> k -> j``. If such a ``k`` exists (and the triple ``(i, k, j)`` is
    not one of the graph's excluded triples), ``i - j`` is oriented as
    ``i -> j``.

    Parameters
    ----------
    graph : CPDAG
        The graph to orient. It is modified in place when the rule fires.
    i : node
        The tail candidate of the edge.
    j : node
        The head candidate of the edge.

    Returns
    -------
    added_arrows : bool
        True if the edge was oriented, False otherwise.
    """
    # The rule only applies to an undirected edge i - j.
    if not graph.has_edge(i, j, graph.undirected_edge_name):
        return False

    directed = graph.directed_edge_name

    # Nodes k with i->k (and no directed edge back from k to i).
    child_i = {k for k in graph.successors(i) if not graph.has_edge(k, i, directed)}

    # Nodes k with k->j (and no directed edge back from j to k).
    parent_j = {k for k in graph.predecessors(j) if not graph.has_edge(j, k, directed)}

    # Nodes k where i->k->j.
    candidate_k = child_i & parent_j

    # Drop candidates whose triple is excluded. Build a new set rather than
    # removing from the one being iterated, which raises
    # "RuntimeError: Set changed size during iteration".
    excluded = graph.excluded_triples
    if excluded:
        candidate_k = {k for k in candidate_k if frozenset((i, k, j)) not in excluded}

    if not candidate_k:
        return False

    # Make i-j into i->j.
    graph.orient_uncertain_edge(i, j)
    return True


def _meek_rule3(graph: CPDAG, i: str, j: str) -> bool:
    """Apply rule 3 of Meek's rules.

    Checks for an undirected edge ``i - j`` and then looks for a collider
    ``k -> j <- l`` with ``k`` and ``l`` nonadjacent, together with
    ``i - k`` and ``i - l``. If found, ``i - j`` is oriented as ``i -> j``.

    Parameters
    ----------
    graph : CPDAG
        The graph to orient. It is modified in place when the rule fires.
    i : node
        The tail candidate of the edge.
    j : node
        The head candidate of the edge.

    Returns
    -------
    added_arrows : bool
        True if the edge was oriented, False otherwise.
    """
    undirected = graph.undirected_edge_name
    directed = graph.directed_edge_name

    # The rule only applies to an undirected edge i - j.
    if not graph.has_edge(i, j, undirected):
        return False

    # For all pairs (k, l) of nodes adjacent to i, look for k -> j and l -> j.
    for k, l in combinations(graph.neighbors(i), 2):
        # Skip if k and l are adjacent.
        if l in graph.neighbors(k):
            continue

        # Skip if not strictly k->j (a directed edge back from j disqualifies).
        if graph.has_edge(j, k, directed) or not graph.has_edge(k, j, directed):
            continue

        # Skip if not strictly l->j.
        if graph.has_edge(j, l, directed) or not graph.has_edge(l, j, directed):
            continue

        # Skip triples that the graph explicitly marks as excluded.
        if frozenset((l, i, k)) in graph.excluded_triples:
            continue

        # With i - k and i - l both undirected we have a valid pattern;
        # orient and stop, since one witness pair is enough.
        if graph.has_edge(k, i, undirected) and graph.has_edge(l, i, undirected):
            graph.orient_uncertain_edge(i, j)
            return True

    return False


def _meek_rule4(graph: CPDAG, i: str, j: str) -> bool:
    """Apply rule 4 of Meek's rules.

    Checks for an undirected edge ``i - j`` and then looks for a path
    ``i - k -> l -> j``. If found, ``i - j`` is oriented as ``i -> j``.

    Parameters
    ----------
    graph : CPDAG
        The graph to orient. It is modified in place when the rule fires.
    i : node
        The tail candidate of the edge.
    j : node
        The head candidate of the edge.

    Returns
    -------
    added_arrows : bool
        True if the edge was oriented, False otherwise.
    """
    # The rule only applies to an undirected edge i - j.
    if not graph.has_edge(i, j, graph.undirected_edge_name):
        return False

    directed = graph.directed_edge_name

    # Neighbors k of i with no directed edge from k into i.
    adj_i = {k for k in graph.neighbors(i) if not graph.has_edge(k, i, directed)}

    # Nodes l with l->j (and no directed edge back from j to l).
    parent_j = {l for l in graph.predecessors(j) if not graph.has_edge(j, l, directed)}

    dedges = set(graph.directed_edges)
    undedges = set(graph.undirected_edges)

    # Look at every (k, l) pair directly. Pairing permutations of ``adj_i``
    # with ``parent_j`` produces no pairs at all when ``parent_j`` is the
    # larger set, so valid k -> l edges would be missed.
    # NOTE(review): the only k/j adjacency test is whether (k, j) is in the
    # undirected edge set, in that orientation — confirm this is sufficient.
    has_witness = any(
        (k, l) in dedges and (k, j) not in undedges
        for k in adj_i
        for l in parent_j
    )
    if not has_witness:
        return False

    # Make i-j into i->j.
    graph.orient_uncertain_edge(i, j)
    return True


def pag_to_mag(graph):
    """Sample a MAG from a PAG using Zhang's algorithm.

    Using the algorithm defined in Theorem 2 of :footcite:`Zhang2008`, which turns all
    o-> edges to -> and -o edges to ->, then it converts the circle component
    (the o-o edges) into a DAG with no unshielded colliders using the meek rules.

    Parameters
    ----------
    graph : PAG
        The PAG. It is not modified; the conversion works on a copy.

    Returns
    -------
    mag : ADMG
        The MAG constructed from the PAG.
    """
    copy_graph = graph.copy()

    cedges = set(copy_graph.circle_edges)
    dedges = set(copy_graph.directed_edges)

    # Holds the circle component (o-o edges) while it is being oriented.
    temp_cpdag = CPDAG()

    to_remove = []
    to_reorient = []
    to_add = []
    seen_circle_pairs = set()  # o-o edges already queued, for O(1) lookup

    for u, v in cedges:
        if (v, u) in dedges:
            # remove the circle end from a 'o-->' edge to make a '-->' edge
            to_remove.append((u, v))
        elif (v, u) not in cedges:
            # reorient a '--o' edge to '-->'
            to_reorient.append((u, v))
        elif (v, u) not in seen_circle_pairs:
            # 'o--o' edge: it appears in ``cedges`` in both orientations,
            # so queue it only once for the cpdag
            seen_circle_pairs.add((u, v))
            to_add.append((u, v))

    # Apply the collected changes only after the scan so that the edge sets
    # are not modified while they are being inspected.
    for u, v in to_remove:
        copy_graph.remove_edge(u, v, copy_graph.circle_edge_name)
    for u, v in to_reorient:
        copy_graph.orient_uncertain_edge(u, v)
    for u, v in to_add:
        temp_cpdag.add_edge(v, u, temp_cpdag.undirected_edge_name)

    # Convert the circle component into a DAG with no unshielded colliders:
    # orient one remaining undirected edge, propagate the consequences with
    # the meek rules, and repeat until nothing is left undirected.
    while True:
        undedges = list(temp_cpdag.undirected_edges)
        if not undedges:
            break
        u, v = undedges[0]
        temp_cpdag.remove_edge(u, v, temp_cpdag.undirected_edge_name)
        temp_cpdag.add_edge(u, v, temp_cpdag.directed_edge_name)
        _apply_meek_rules(temp_cpdag)

    mag = ADMG()  # provisional MAG

    # construct the final MAG

    # keep nodes that end up without any edge
    mag.add_nodes_from(copy_graph.nodes)

    for u, v in copy_graph.directed_edges:
        mag.add_edge(u, v, mag.directed_edge_name)

    for u, v in temp_cpdag.directed_edges:
        mag.add_edge(u, v, mag.directed_edge_name)

    # Bidirected edges carry no circle endpoints, so they are untouched by
    # the conversion and must be carried over to the MAG unchanged.
    # NOTE(review): undirected (selection-bias) edges of the PAG are still
    # not copied — confirm how '--o' edges are stored before adding them.
    for u, v in copy_graph.bidirected_edges:
        mag.add_edge(u, v, mag.bidirected_edge_name)

    return mag
77 changes: 77 additions & 0 deletions pywhy_graphs/algorithms/tests/test_pag.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,3 +647,80 @@ def test_pdst(pdst_graph):
ex_pdsep_t = pds_t(G, ("E", 0), ("x", -1))
assert ("y", -2) not in xe_pdsep_t
assert ("y", -2) not in ex_pdsep_t


def test_pag_to_mag():
    """Check that ``pag_to_mag`` keeps the adjacencies of three small PAGs.

    Each scenario builds a PAG, converts it, and verifies that every edge
    that must be directed is present and that every o-o edge has been
    oriented in one of its two possible directions.
    """
    # Scenario 1: three nodes A, B, C that each point into D with a circle
    # at their own end, plus two o-o edges hanging off A.
    #
    # C o- A o-> D <-o B
    # B o-o A o-o C o-> D
    pag = PAG()
    pag.add_edge("A", "D", pag.directed_edge_name)
    pag.add_edge("A", "C", pag.circle_edge_name)
    pag.add_edge("D", "A", pag.circle_edge_name)
    pag.add_edge("B", "D", pag.directed_edge_name)
    pag.add_edge("C", "D", pag.directed_edge_name)
    pag.add_edge("D", "B", pag.circle_edge_name)
    pag.add_edge("D", "C", pag.circle_edge_name)
    pag.add_edge("C", "A", pag.circle_edge_name)
    pag.add_edge("B", "A", pag.circle_edge_name)
    pag.add_edge("A", "B", pag.circle_edge_name)

    out_mag = pywhy_graphs.pag_to_mag(pag)

    # Expected:
    # C <- A -> B -> D or C -> A -> B -> D or C <- A <- B -> D
    # A -> D <- C
    assert out_mag.has_edge("A", "B") or out_mag.has_edge("B", "A")
    assert out_mag.has_edge("A", "C") or out_mag.has_edge("C", "A")
    assert out_mag.has_edge("A", "D")
    assert out_mag.has_edge("B", "D")
    assert out_mag.has_edge("C", "D")

    # Scenario 2: a collider at A whose two parents are joined by an o-o edge.
    #
    # D o-> A <-o B
    # D o-o B
    pag = PAG()
    pag.add_edge("A", "B", pag.circle_edge_name)
    pag.add_edge("B", "A", pag.directed_edge_name)
    pag.add_edge("D", "A", pag.directed_edge_name)
    pag.add_edge("A", "D", pag.circle_edge_name)
    pag.add_edge("D", "B", pag.circle_edge_name)
    pag.add_edge("B", "D", pag.circle_edge_name)

    out_mag = pywhy_graphs.pag_to_mag(pag)

    # Expected:
    # B -> A <- D
    # D -> B or D <- B
    assert out_mag.has_edge("B", "A")
    assert out_mag.has_edge("D", "A")
    assert out_mag.has_edge("D", "B") or out_mag.has_edge("B", "D")

    # Scenario 3: fully directed edges into B must be preserved while the
    # o-o edges around D are oriented.
    #
    # A -> B <- C o-o D
    # D o-o E -> B
    pag = PAG()
    pag.add_edge("A", "B", pag.directed_edge_name)
    pag.add_edge("C", "B", pag.directed_edge_name)
    pag.add_edge("E", "B", pag.directed_edge_name)
    pag.add_edge("E", "D", pag.circle_edge_name)
    pag.add_edge("C", "D", pag.circle_edge_name)
    pag.add_edge("D", "E", pag.circle_edge_name)
    pag.add_edge("D", "C", pag.circle_edge_name)

    out_mag = pywhy_graphs.pag_to_mag(pag)

    # Expected:
    # A -> B <- C <- D or A -> B <- C -> D
    # D <- E -> B or D -> E -> B
    assert out_mag.has_edge("A", "B")
    assert out_mag.has_edge("C", "B")
    assert out_mag.has_edge("E", "B")
    assert out_mag.has_edge("E", "D") or out_mag.has_edge("D", "E")
    assert out_mag.has_edge("D", "C") or out_mag.has_edge("C", "D")
Loading