Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add the ability to convert a DAG to an MAG #96

Merged
merged 28 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
9221f64
Added function signature
aryan26roy Sep 8, 2023
74b6612
Merge branch 'main' into aryan_dag_to_mag
aryan26roy Sep 8, 2023
03eb9cb
fixed the return statement
aryan26roy Sep 8, 2023
65ff956
Rebased onto main
aryan26roy Sep 9, 2023
1e42a20
testing
aryan26roy Sep 9, 2023
2b84d74
Bump actions/checkout from 3 to 4 (#94)
dependabot[bot] Sep 6, 2023
f0f25d4
Added function signature and rebased
aryan26roy Sep 8, 2023
f293c99
Resolved merge issue
aryan26roy Sep 9, 2023
5df8578
Wrote down the algorithm
aryan26roy Sep 9, 2023
7fd8c93
Added partial implementation
aryan26roy Sep 9, 2023
e5f16bb
Added page number
aryan26roy Sep 11, 2023
e247bd7
Completed implementation
aryan26roy Sep 12, 2023
b8b69d2
Added some tests and improved the implementation
aryan26roy Sep 18, 2023
f6656d4
Added two working tests
aryan26roy Sep 21, 2023
7832e88
Added another test
aryan26roy Sep 21, 2023
302733a
Linting
aryan26roy Sep 21, 2023
916743a
Changelog and linting
aryan26roy Sep 21, 2023
4e9b7e5
Linting
aryan26roy Sep 21, 2023
f7eb0b8
implemented reccomendations on code readabiliy
aryan26roy Sep 21, 2023
7e30a76
Linting
aryan26roy Sep 21, 2023
4daf699
Update pywhy_graphs/algorithms/generic.py
aryan26roy Sep 23, 2023
822e355
Update pywhy_graphs/algorithms/generic.py
aryan26roy Sep 23, 2023
4657c72
Added adc check to each MAG
aryan26roy Sep 23, 2023
29ac9cf
Added tests for inducing path edge case
aryan26roy Sep 23, 2023
f8865e6
Added checks for maximality in tests
aryan26roy Sep 26, 2023
58723ed
improved memoisation
aryan26roy Sep 26, 2023
d5fb5b9
Using frozenset
aryan26roy Sep 26, 2023
6cf1edf
Added test for is_maximal
aryan26roy Sep 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ causal graph operations.
.. autosummary::
:toctree: generated/

dag_to_mag
valid_mag
has_adc
inducing_path
Expand Down
1 change: 1 addition & 0 deletions doc/whats_new/v0.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Version 0.2
Changelog
---------
- |Feature| Implement and test functions to validate an MAG and check the presence of almost directed cycles, by `Aryan Roy`_ (:pr:`91`)
- |Feature| Implement and test functions to convert a DAG to MAG, by `Aryan Roy`_ (:pr:`96`)

Code and Documentation Contributors
-----------------------------------
Expand Down
120 changes: 120 additions & 0 deletions pywhy_graphs/algorithms/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"inducing_path",
"has_adc",
"valid_mag",
"dag_to_mag",
"is_maximal",
]


Expand Down Expand Up @@ -554,26 +556,29 @@
.. footbibliography::
"""
if L is None:
L = set()

Check warning on line 559 in pywhy_graphs/algorithms/generic.py

View check run for this annotation

Codecov / codecov/patch

pywhy_graphs/algorithms/generic.py#L559

Added line #L559 was not covered by tests

if S is None:
S = set()

Check warning on line 562 in pywhy_graphs/algorithms/generic.py

View check run for this annotation

Codecov / codecov/patch

pywhy_graphs/algorithms/generic.py#L562

Added line #L562 was not covered by tests

nodes = set(G.nodes)

if node_x not in nodes or node_y not in nodes:
raise ValueError("The provided nodes are not in the graph.")

Check warning on line 567 in pywhy_graphs/algorithms/generic.py

View check run for this annotation

Codecov / codecov/patch

pywhy_graphs/algorithms/generic.py#L567

Added line #L567 was not covered by tests

if node_x == node_y:
raise ValueError("The source and destination nodes are the same.")

Check warning on line 570 in pywhy_graphs/algorithms/generic.py

View check run for this annotation

Codecov / codecov/patch

pywhy_graphs/algorithms/generic.py#L570

Added line #L570 was not covered by tests

if (node_x in L) or (node_y in L) or (node_x in S) or (node_y in S):
return (False, [])

edges = G.edges()

# XXX: fix this when graphs are refactored to only check for directed/bidirected edge types
for elem in edges.keys():
if elem not in {"directed", "bidirected"}:
if len(edges[elem]) != 0:
raise ValueError("Inducing Path is not defined for this graph.")

Check warning on line 581 in pywhy_graphs/algorithms/generic.py

View check run for this annotation

Codecov / codecov/patch

pywhy_graphs/algorithms/generic.py#L581

Added line #L581 was not covered by tests

path = [] # this will contain the path.

Expand Down Expand Up @@ -703,3 +708,118 @@
return False

return True


def dag_to_mag(G, L: Set = None, S: Set = None):
"""Converts a DAG to a valid MAG.

The algorithm is defined in :footcite:`Zhang2008` on page 1877.
aryan26roy marked this conversation as resolved.
Show resolved Hide resolved

Parameters:
-----------
G : Graph
The graph.
L : Set
Nodes that are ignored on the path. Defaults to an empty set.
S : Set
Nodes that are always conditioned on. Defaults to an empty set.

Returns
-------
mag : Graph
The MAG.
"""

if L is None:
L = set()

if S is None:
S = set()
aryan26roy marked this conversation as resolved.
Show resolved Hide resolved

# for each pair of nodes find if they have an inducing path between them.
# only then will they be adjacent in the MAG.

all_nodes = set(G.nodes)
adj_nodes = []

for source in all_nodes:
copy_all = all_nodes.copy()
copy_all.remove(source)
for dest in copy_all:
out = inducing_path(G, source, dest, L, S)
if out[0] is True and {source, dest} not in adj_nodes:
adj_nodes.append({source, dest})

# find the ancestors of B U S (ansB) and A U S (ansA) for each pair of adjacent nodes

mag = ADMG()

for A, B in adj_nodes:

AuS = S.union(A)
BuS = S.union(B)

ansA: Set = set()
ansB: Set = set()

for node in AuS:
ansA = ansA.union(_directed_sub_graph_ancestors(G, node))

for node in BuS:
ansB = ansB.union(_directed_sub_graph_ancestors(G, node))

if A in ansB and B not in ansA:
# if A is in ansB and B is not in ansA, A -> B
mag.add_edge(A, B, mag.directed_edge_name)

elif A not in ansB and B in ansA:
# if B is in ansA and A is not in ansB, A <- B
mag.add_edge(B, A, mag.directed_edge_name)

elif A not in ansB and B not in ansA:
# if A is not in ansB and B is not in ansA, A <-> B
mag.add_edge(B, A, mag.bidirected_edge_name)

elif A in ansB and B in ansA:
# if A is in ansB and B is in ansA, A - B
mag.add_edge(B, A, mag.undirected_edge_name)

return mag


def is_maximal(G, L: Set = None, S: Set = None):
"""Checks to see if the graph is maximal.

Parameters:
-----------
G : Graph
The graph.

Returns
-------
is_maximal : bool
A boolean indicating whether the provided graph is maximal or not.
"""

if L is None:
L = set()

if S is None:
S = set()

all_nodes = set(G.nodes)
checked = set()
for source in all_nodes:
nb = set(G.neighbors(source))
cur_set = all_nodes - nb
cur_set.remove(source)
for dest in cur_set:
current_pair = frozenset({source, dest})
if current_pair not in checked:
checked.add(current_pair)
out = inducing_path(G, source, dest, L, S)
if out[0] is True:
return False
aryan26roy marked this conversation as resolved.
Show resolved Hide resolved
else:
continue
return True
124 changes: 124 additions & 0 deletions pywhy_graphs/algorithms/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,30 @@ def test_inducing_path_corner_cases():

assert pywhy_graphs.inducing_path(admg, "X", "Y", L, S)[0]

# X -> Z <- Y, A <- B <- Z
admg = ADMG()
admg.add_edge("X", "Z", admg.directed_edge_name)
admg.add_edge("Y", "Z", admg.directed_edge_name)
admg.add_edge("Z", "B", admg.directed_edge_name)
admg.add_edge("B", "A", admg.directed_edge_name)

L = {"X"}
S = {"A"}

assert not pywhy_graphs.inducing_path(admg, "X", "Y", L, S)[0]

# X -> Z <- Y, A <- B <- Z
admg = ADMG()
admg.add_edge("X", "Z", admg.directed_edge_name)
admg.add_edge("Y", "Z", admg.directed_edge_name)
admg.add_edge("Z", "B", admg.directed_edge_name)
admg.add_edge("B", "A", admg.directed_edge_name)

L = {}
S = {"A", "Y"}

assert not pywhy_graphs.inducing_path(admg, "X", "Y", L, S)[0]


def test_is_collider():
# Z -> X -> A <- B -> Y; H -> A
Expand Down Expand Up @@ -348,3 +372,103 @@ def test_valid_mag():
admg.add_edge("H", "J", admg.undirected_edge_name)

assert not pywhy_graphs.valid_mag(admg) # there is an undirected edge between H and J


def test_dag_to_mag():
aryan26roy marked this conversation as resolved.
Show resolved Hide resolved
aryan26roy marked this conversation as resolved.
Show resolved Hide resolved

# A -> E -> S
# H -> E , H -> R
admg = ADMG()
admg.add_edge("A", "E", admg.directed_edge_name)
admg.add_edge("E", "S", admg.directed_edge_name)
admg.add_edge("H", "E", admg.directed_edge_name)
admg.add_edge("H", "R", admg.directed_edge_name)

S = {"S"}
L = {"H"}

out_mag = pywhy_graphs.dag_to_mag(admg, L, S)
assert pywhy_graphs.is_maximal(out_mag)
assert not pywhy_graphs.has_adc(out_mag)
out_edges = out_mag.edges()
dir_edges = list(out_edges["directed"])
assert (
("A", "R") in out_edges["directed"]
and ("E", "R") in out_edges["directed"]
and len(out_edges["directed"]) == 2
)
assert ("A", "E") in out_edges["undirected"]

out_mag = pywhy_graphs.dag_to_mag(admg)
dir_edges = list(out_mag.edges()["directed"])

assert (
("A", "E") in dir_edges
and ("E", "S") in dir_edges
and ("H", "E") in dir_edges
and ("H", "R") in dir_edges
)

# A -> E -> S <- H
# H -> E , H -> R,

admg = ADMG()
admg.add_edge("A", "E", admg.directed_edge_name)
admg.add_edge("H", "S", admg.directed_edge_name)
admg.add_edge("H", "E", admg.directed_edge_name)
admg.add_edge("H", "R", admg.directed_edge_name)

S = {"S"}
L = {"H"}

out_mag = pywhy_graphs.dag_to_mag(admg, L, S)
assert pywhy_graphs.is_maximal(out_mag)
assert not pywhy_graphs.has_adc(out_mag)
out_edges = out_mag.edges()

dir_edges = list(out_edges["directed"])
assert ("A", "E") in out_edges["directed"] and len(out_edges["directed"]) == 1
assert ("E", "R") in out_edges["bidirected"]

# P -> S -> L <- G
# G -> S -> I <- J
# J -> S

admg = ADMG()
admg.add_edge("P", "S", admg.directed_edge_name)
admg.add_edge("S", "L", admg.directed_edge_name)
admg.add_edge("G", "S", admg.directed_edge_name)
admg.add_edge("G", "L", admg.directed_edge_name)
admg.add_edge("I", "S", admg.directed_edge_name)
admg.add_edge("J", "I", admg.directed_edge_name)
admg.add_edge("J", "S", admg.directed_edge_name)

S = set()
L = {"J"}

out_mag = pywhy_graphs.dag_to_mag(admg, L, S)
assert pywhy_graphs.is_maximal(out_mag)
assert not pywhy_graphs.has_adc(out_mag)
out_edges = out_mag.edges()
dir_edges = list(out_edges["directed"])
assert (
("G", "S") in dir_edges
and ("G", "L") in dir_edges
and ("S", "L") in dir_edges
and ("I", "S") in dir_edges
and ("P", "S") in dir_edges
and len(dir_edges) == 5
)


def test_is_maximal():
# X <- Y <-> Z <-> H; Z -> X
admg = ADMG()
admg.add_edge("Y", "X", admg.directed_edge_name)
admg.add_edge("Z", "X", admg.directed_edge_name)
admg.add_edge("Z", "Y", admg.bidirected_edge_name)
admg.add_edge("Z", "H", admg.bidirected_edge_name)

S = {}
L = {"Y"}
assert not pywhy_graphs.is_maximal(admg, L, S)
Loading