Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add the ability to convert a DAG to an MAG #96

Merged
merged 28 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
9221f64
Added function signature
aryan26roy Sep 8, 2023
74b6612
Merge branch 'main' into aryan_dag_to_mag
aryan26roy Sep 8, 2023
03eb9cb
fixed the return statement
aryan26roy Sep 8, 2023
65ff956
Rebased onto main
aryan26roy Sep 9, 2023
1e42a20
testing
aryan26roy Sep 9, 2023
2b84d74
Bump actions/checkout from 3 to 4 (#94)
dependabot[bot] Sep 6, 2023
f0f25d4
Added function signature and rebased
aryan26roy Sep 8, 2023
f293c99
Resolved merge issue
aryan26roy Sep 9, 2023
5df8578
Wrote down the algorithm
aryan26roy Sep 9, 2023
7fd8c93
Added partial implementation
aryan26roy Sep 9, 2023
e5f16bb
Added page number
aryan26roy Sep 11, 2023
e247bd7
Completed implementation
aryan26roy Sep 12, 2023
b8b69d2
Added some tests and improved the implementation
aryan26roy Sep 18, 2023
f6656d4
Added two working tests
aryan26roy Sep 21, 2023
7832e88
Added another test
aryan26roy Sep 21, 2023
302733a
Linting
aryan26roy Sep 21, 2023
916743a
Changelog and linting
aryan26roy Sep 21, 2023
4e9b7e5
Linting
aryan26roy Sep 21, 2023
f7eb0b8
implemented reccomendations on code readabiliy
aryan26roy Sep 21, 2023
7e30a76
Linting
aryan26roy Sep 21, 2023
4daf699
Update pywhy_graphs/algorithms/generic.py
aryan26roy Sep 23, 2023
822e355
Update pywhy_graphs/algorithms/generic.py
aryan26roy Sep 23, 2023
4657c72
Added adc check to each MAG
aryan26roy Sep 23, 2023
29ac9cf
Added tests for inducing path edge case
aryan26roy Sep 23, 2023
f8865e6
Added checks for maximality in tests
aryan26roy Sep 26, 2023
58723ed
improved memoisation
aryan26roy Sep 26, 2023
d5fb5b9
Using frozenset
aryan26roy Sep 26, 2023
6cf1edf
Added test for is_maximal
aryan26roy Sep 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ causal graph operations.
.. autosummary::
:toctree: generated/

dag_to_mag
valid_mag
has_adc
inducing_path
Expand Down
1 change: 1 addition & 0 deletions doc/whats_new/v0.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Version 0.2
Changelog
---------
- |Feature| Implement and test functions to validate an MAG and check the presence of almost directed cycles, by `Aryan Roy`_ (:pr:`91`)
- |Feature| Implement and test functions to convert a DAG to MAG, by `Aryan Roy`_ (:pr:`96`)

Code and Documentation Contributors
-----------------------------------
Expand Down
78 changes: 78 additions & 0 deletions pywhy_graphs/algorithms/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"inducing_path",
"has_adc",
"valid_mag",
"dag_to_mag",
]


Expand Down Expand Up @@ -554,26 +555,29 @@
.. footbibliography::
"""
if L is None:
L = set()

Check warning on line 558 in pywhy_graphs/algorithms/generic.py

View check run for this annotation

Codecov / codecov/patch

pywhy_graphs/algorithms/generic.py#L558

Added line #L558 was not covered by tests

if S is None:
S = set()

Check warning on line 561 in pywhy_graphs/algorithms/generic.py

View check run for this annotation

Codecov / codecov/patch

pywhy_graphs/algorithms/generic.py#L561

Added line #L561 was not covered by tests

nodes = set(G.nodes)

if node_x not in nodes or node_y not in nodes:
raise ValueError("The provided nodes are not in the graph.")

Check warning on line 566 in pywhy_graphs/algorithms/generic.py

View check run for this annotation

Codecov / codecov/patch

pywhy_graphs/algorithms/generic.py#L566

Added line #L566 was not covered by tests

if node_x == node_y:
raise ValueError("The source and destination nodes are the same.")

Check warning on line 569 in pywhy_graphs/algorithms/generic.py

View check run for this annotation

Codecov / codecov/patch

pywhy_graphs/algorithms/generic.py#L569

Added line #L569 was not covered by tests

if (node_x in L) or (node_y in L) or (node_x in S) or (node_y in S):
return (False, [])

edges = G.edges()

# XXX: fix this when graphs are refactored to only check for directed/bidirected edge types
for elem in edges.keys():
if elem not in {"directed", "bidirected"}:
if len(edges[elem]) != 0:
raise ValueError("Inducing Path is not defined for this graph.")

Check warning on line 580 in pywhy_graphs/algorithms/generic.py

View check run for this annotation

Codecov / codecov/patch

pywhy_graphs/algorithms/generic.py#L580

Added line #L580 was not covered by tests

path = [] # this will contain the path.

Expand Down Expand Up @@ -703,3 +707,77 @@
return False

return True


def dag_to_mag(G, L: Set = None, S: Set = None):
    """Convert a DAG to a valid MAG.

    The algorithm is defined in :footcite:`Zhang2008` on page 1877.

    Two nodes are made adjacent in the MAG when ``inducing_path`` reports an
    inducing path between them relative to ``L`` and ``S``. The mark on each
    adjacency is then chosen from the ancestral relationships in ``G``.

    Parameters
    ----------
    G : Graph
        The graph.
    L : Set
        Nodes that are ignored on the path. Defaults to an empty set.
    S : Set
        Nodes that are always conditioned on. Defaults to an empty set.

    Returns
    -------
    mag : ADMG
        The resulting graph. Only nodes that are an endpoint of at least one
        edge are added to it.

    References
    ----------
    .. footbibliography::
    """
    if L is None:
        L = set()

    if S is None:
        S = set()

    # For each pair of nodes find if they have an inducing path between them.
    # Only then will they be adjacent in the MAG.
    all_nodes = set(G.nodes)
    adj_pairs = []
    seen_pairs = set()  # frozensets give O(1) duplicate detection for unordered pairs

    for source in all_nodes:
        for dest in all_nodes:
            if source == dest:
                continue
            pair = frozenset((source, dest))
            if pair in seen_pairs:
                # already known to be adjacent from the opposite direction
                continue
            if inducing_path(G, source, dest, L, S)[0] is True:
                seen_pairs.add(pair)
                adj_pairs.append((source, dest))

    # Find the ancestors of B U S (ansB) and A U S (ansA) for each pair of
    # adjacent nodes and orient the edge accordingly.
    mag = ADMG()

    for node_a, node_b in adj_pairs:
        # Wrap each node in a singleton set. ``set(node)`` would iterate over
        # the node itself, splitting a string name into characters and failing
        # on non-iterable nodes such as integers.
        aus = S.union({node_a})
        bus = S.union({node_b})

        ansA: Set = set()
        ansB: Set = set()

        for node in aus:
            ansA = ansA.union(_directed_sub_graph_ancestors(G, node))

        for node in bus:
            ansB = ansB.union(_directed_sub_graph_ancestors(G, node))

        a_in_ansB = node_a in ansB
        b_in_ansA = node_b in ansA

        if a_in_ansB and not b_in_ansA:
            # if A is in ansB and B is not in ansA, A -> B
            mag.add_edge(node_a, node_b, mag.directed_edge_name)
        elif not a_in_ansB and b_in_ansA:
            # if B is in ansA and A is not in ansB, A <- B
            mag.add_edge(node_b, node_a, mag.directed_edge_name)
        elif not a_in_ansB and not b_in_ansA:
            # if A is not in ansB and B is not in ansA, A <-> B
            mag.add_edge(node_b, node_a, mag.bidirected_edge_name)
        else:
            # if A is in ansB and B is in ansA, A - B
            mag.add_edge(node_b, node_a, mag.undirected_edge_name)

    return mag
60 changes: 60 additions & 0 deletions pywhy_graphs/algorithms/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,63 @@ def test_valid_mag():
admg.add_edge("H", "J", admg.undirected_edge_name)

assert not pywhy_graphs.valid_mag(admg) # there is an undirected edge between H and J


def test_dag_to_mag():
    """Check the edges produced by ``dag_to_mag`` on three small graphs."""
    # Scenario 1: A -> E -> S, H -> E, H -> R, with S in the conditioning set
    # and H in the ignored set.
    admg = ADMG()
    admg.add_edge("A", "E", admg.directed_edge_name)
    admg.add_edge("E", "S", admg.directed_edge_name)
    admg.add_edge("H", "E", admg.directed_edge_name)
    admg.add_edge("H", "R", admg.directed_edge_name)

    S = {"S"}
    L = {"H"}

    out_edges = pywhy_graphs.dag_to_mag(admg, L, S).edges()

    # One assertion per expectation so a failure names the offending edge.
    dir_edges = out_edges["directed"]
    assert ("A", "R") in dir_edges
    assert ("E", "R") in dir_edges
    assert len(dir_edges) == 2
    assert ("A", "E") in out_edges["undirected"]

    # Scenario 2: same as above, except S is now a child of H instead of E.
    admg = ADMG()
    admg.add_edge("A", "E", admg.directed_edge_name)
    admg.add_edge("H", "S", admg.directed_edge_name)
    admg.add_edge("H", "E", admg.directed_edge_name)
    admg.add_edge("H", "R", admg.directed_edge_name)

    S = {"S"}
    L = {"H"}

    out_edges = pywhy_graphs.dag_to_mag(admg, L, S).edges()

    dir_edges = out_edges["directed"]
    assert ("A", "E") in dir_edges
    assert len(dir_edges) == 1
    assert ("E", "R") in out_edges["bidirected"]

    # Scenario 3: no conditioning set; J is ignored.
    admg = ADMG()
    admg.add_edge("P", "S", admg.directed_edge_name)
    admg.add_edge("S", "L", admg.directed_edge_name)
    admg.add_edge("G", "S", admg.directed_edge_name)
    admg.add_edge("G", "L", admg.directed_edge_name)
    admg.add_edge("I", "S", admg.directed_edge_name)
    admg.add_edge("J", "I", admg.directed_edge_name)
    admg.add_edge("J", "S", admg.directed_edge_name)

    S = set()
    L = {"J"}

    out_edges = pywhy_graphs.dag_to_mag(admg, L, S).edges()

    dir_edges = out_edges["directed"]
    assert ("G", "S") in dir_edges
    assert ("G", "L") in dir_edges
    assert ("S", "L") in dir_edges
    assert ("I", "S") in dir_edges
    assert ("P", "S") in dir_edges
    assert len(dir_edges) == 5
Loading