Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add the ability to convert a PAG to MAG #93

Merged
merged 24 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
poetry-version: [1.5.1]
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Setup Python 3.9
uses: actions/setup-python@v4
with:
Expand Down Expand Up @@ -65,7 +65,7 @@ jobs:
shell: bash
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
Expand Down Expand Up @@ -123,7 +123,7 @@ jobs:
shell: bash
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
Expand Down Expand Up @@ -168,7 +168,7 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pr_checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
run: |
echo "PR_NUMBER=${{ github.event.pull_request.number }}" >> $GITHUB_ENV
echo "TAGGED_MILESTONE=${{ github.event.pull_request.milestone.title }}" >> $GITHUB_ENV
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
fetch-depth: '0'
- name: Check that CHANGELOG has been updated
Expand Down
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ causal graph operations.
.. autosummary::
:toctree: generated/

dag_to_mag
valid_mag
has_adc
inducing_path
Expand Down
2 changes: 2 additions & 0 deletions doc/whats_new/v0.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ Version 0.2
Changelog
---------
- |Feature| Implement and test functions to validate an MAG and check the presence of almost directed cycles, by `Aryan Roy`_ (:pr:`91`)
- |Feature| Implement and test functions to convert a DAG to MAG, by `Aryan Roy`_ (:pr:`96`)
- |Feature| Implement and test functions to convert a PAG to MAG, by `Aryan Roy`_ (:pr:`93`)

Code and Documentation Contributors
-----------------------------------
Expand Down
120 changes: 120 additions & 0 deletions pywhy_graphs/algorithms/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"inducing_path",
"has_adc",
"valid_mag",
"dag_to_mag",
"is_maximal",
]


Expand Down Expand Up @@ -567,6 +569,9 @@ def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None):
if node_x == node_y:
raise ValueError("The source and destination nodes are the same.")

if (node_x in L) or (node_y in L) or (node_x in S) or (node_y in S):
return (False, [])

edges = G.edges()

# XXX: fix this when graphs are refactored to only check for directed/bidirected edge types
Expand Down Expand Up @@ -703,3 +708,118 @@ def valid_mag(G: ADMG, L: set = None, S: set = None):
return False

return True


def dag_to_mag(G, L: Set = None, S: Set = None):
"""Converts a DAG to a valid MAG.

The algorithm is defined in :footcite:`Zhang2008` on page 1877.

Parameters:
-----------
G : Graph
The graph.
L : Set
Nodes that are ignored on the path. Defaults to an empty set.
S : Set
Nodes that are always conditioned on. Defaults to an empty set.

Returns
-------
mag : Graph
The MAG.
"""

if L is None:
L = set()

if S is None:
S = set()

# for each pair of nodes find if they have an inducing path between them.
# only then will they be adjacent in the MAG.

all_nodes = set(G.nodes)
adj_nodes = []

for source in all_nodes:
copy_all = all_nodes.copy()
copy_all.remove(source)
for dest in copy_all:
out = inducing_path(G, source, dest, L, S)
if out[0] is True and {source, dest} not in adj_nodes:
adj_nodes.append({source, dest})

# find the ancestors of B U S (ansB) and A U S (ansA) for each pair of adjacent nodes

mag = ADMG()

for A, B in adj_nodes:

AuS = S.union(A)
BuS = S.union(B)

ansA: Set = set()
ansB: Set = set()

for node in AuS:
ansA = ansA.union(_directed_sub_graph_ancestors(G, node))

for node in BuS:
ansB = ansB.union(_directed_sub_graph_ancestors(G, node))

if A in ansB and B not in ansA:
# if A is in ansB and B is not in ansA, A -> B
mag.add_edge(A, B, mag.directed_edge_name)

elif A not in ansB and B in ansA:
# if B is in ansA and A is not in ansB, A <- B
mag.add_edge(B, A, mag.directed_edge_name)

elif A not in ansB and B not in ansA:
# if A is not in ansB and B is not in ansA, A <-> B
mag.add_edge(B, A, mag.bidirected_edge_name)

elif A in ansB and B in ansA:
# if A is in ansB and B is in ansA, A - B
mag.add_edge(B, A, mag.undirected_edge_name)

return mag


def is_maximal(G, L: Set = None, S: Set = None):
"""Checks to see if the graph is maximal.

Parameters:
-----------
G : Graph
The graph.

Returns
-------
is_maximal : bool
A boolean indicating whether the provided graph is maximal or not.
"""

if L is None:
L = set()

if S is None:
S = set()

all_nodes = set(G.nodes)
checked = set()
for source in all_nodes:
nb = set(G.neighbors(source))
cur_set = all_nodes - nb
cur_set.remove(source)
for dest in cur_set:
current_pair = frozenset({source, dest})
if current_pair not in checked:
checked.add(current_pair)
out = inducing_path(G, source, dest, L, S)
if out[0] is True:
return False
else:
continue
return True
Loading