Skip to content

Commit

Permalink
Merge main
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Sep 29, 2023
2 parents c8de7a3 + 9f3e202 commit 75c261c
Show file tree
Hide file tree
Showing 6 changed files with 474 additions and 5 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
poetry-version: [1.5.1]
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Setup Python 3.9
uses: actions/setup-python@v4
with:
Expand Down Expand Up @@ -65,7 +65,7 @@ jobs:
shell: bash
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
Expand Down Expand Up @@ -123,7 +123,7 @@ jobs:
shell: bash
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
Expand Down Expand Up @@ -168,7 +168,7 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pr_checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
run: |
echo "PR_NUMBER=${{ github.event.pull_request.number }}" >> $GITHUB_ENV
echo "TAGGED_MILESTONE=${{ github.event.pull_request.milestone.title }}" >> $GITHUB_ENV
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
fetch-depth: '0'
- name: Check that CHANGELOG has been updated
Expand Down
3 changes: 3 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ causal graph operations.
.. autosummary::
:toctree: generated/

dag_to_mag
valid_mag
has_adc
inducing_path
is_valid_mec_graph
possible_ancestors
Expand Down
3 changes: 3 additions & 0 deletions doc/whats_new/v0.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Version 0.2

Changelog
---------
- |Feature| Implement and test functions to validate an MAG and check the presence of almost directed cycles, by `Aryan Roy`_ (:pr:`91`)
- |Feature| Implement and test functions to convert a DAG to MAG, by `Aryan Roy`_ (:pr:`96`)
- |Feature| Add algorithms for interfacing with a selection diagram in ``pywhy_graphs.algorithms.multidomain``, by `Adam Li`_ (:pr:`88`)

Code and Documentation Contributors
Expand All @@ -34,4 +36,5 @@ Thanks to everyone who has contributed to the maintenance and improvement of
the project since version inception, including:

* `Adam Li`_
* `Aryan Roy`_

219 changes: 219 additions & 0 deletions pywhy_graphs/algorithms/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
"set_nodes_as_latent_confounders",
"is_valid_mec_graph",
"inducing_path",
"has_adc",
"valid_mag",
"dag_to_mag",
"is_maximal",
]


Expand Down Expand Up @@ -567,6 +571,9 @@ def inducing_path(G, node_x: Node, node_y: Node, L: Optional[Set] = None, S: Opt
if node_x == node_y:
raise ValueError("The source and destination nodes are the same.")

if (node_x in L) or (node_y in L) or (node_x in S) or (node_y in S):
return (False, [])

edges = G.edges()

# XXX: fix this when graphs are refactored to only check for directed/bidirected edge types
Expand Down Expand Up @@ -605,3 +612,215 @@ def inducing_path(G, node_x: Node, node_y: Node, L: Optional[Set] = None, S: Opt
break

return (path_exists, path)


def has_adc(G):
    """Check if a graph has an almost directed cycle (adc).

    An almost directed cycle is a directed cycle in which one edge is
    replaced by a bidirected edge. For example, ``A -> B -> C <-> A`` is an
    adc, and so is the two-node case ``A -> B <-> A``.

    Parameters
    ----------
    G : Graph
        The graph. It must provide ``bidirected_edges`` and
        ``sub_directed_graph()``.

    Returns
    -------
    adc_present : bool
        A boolean indicating whether an almost directed cycle is present or not.
    """
    # The directed sub-graph does not change while we search, so fetch it once
    # instead of once (or twice) per node.
    directed = G.sub_directed_graph()

    # Descendant sets are cached so each endpoint is expanded at most once,
    # however many bidirected edges touch it.
    reachable = {}

    for edge in G.bidirected_edges:
        u, v = edge[0], edge[1]
        # A bidirected edge u <-> v closes an almost directed cycle exactly
        # when one endpoint can reach the other along directed edges. Checking
        # reachability between the endpoints directly also covers the case
        # where the directed path is a single edge (no intermediate node).
        for src, dst in ((u, v), (v, u)):
            if src not in reachable:
                reachable[src] = nx.descendants(directed, src)
            if dst in reachable[src]:
                return True

    return False


def valid_mag(G: ADMG, L: Optional[Set] = None, S: Optional[Set] = None):
    """Checks if the provided graph is a valid maximal ancestral graph (MAG).

    A valid MAG as defined in :footcite:`Zhang2008` is a mixed edge graph that
    only has directed and bi-directed edges, no directed or almost directed
    cycles and no inducing paths between any two non-adjacent pair of nodes.

    Parameters
    ----------
    G : Graph
        The graph.
    L : Set
        Nodes that are ignored on the path; forwarded to ``inducing_path``.
        Defaults to an empty set.
    S : Set
        Nodes that are always conditioned on; forwarded to ``inducing_path``.
        Defaults to an empty set.

    Returns
    -------
    is_valid : bool
        A boolean indicating whether the provided graph is a valid MAG or not.
    """
    if L is None:
        L = set()

    if S is None:
        S = set()

    all_nodes = set(G.nodes)

    # The graph is not modified below, so each neighbourhood is computed once
    # and reused by both the edge-type check and the maximality check.
    neighbors = {node: set(G.neighbors(node)) for node in all_nodes}

    # check if there are any undirected edges or more than one edge b/w two nodes
    for node, nbrs in neighbors.items():
        if node in nbrs:
            # A node adjacent to itself is a self-loop, which rules out a MAG.
            # Returning here also keeps the non-adjacency computation further
            # down well defined for every node.
            return False
        for nbr in nbrs:
            edge_data = G.get_edge_data(node, nbr)
            if edge_data["undirected"] is not None:
                return False
            if (edge_data["bidirected"] is not None) and (edge_data["directed"] is not None):
                return False

    # check if there are any directed cycles
    try:
        nx.find_cycle(G.sub_directed_graph())  # raises NetworkXNoCycle if acyclic
    except nx.NetworkXNoCycle:
        pass
    else:
        return False

    # check if there are any almost directed cycles
    if has_adc(G):  # if there is an ADC, it's not a valid MAG
        return False

    # check if there are any inducing paths between non-adjacent nodes
    for source, nbrs in neighbors.items():
        for dest in all_nodes - nbrs - {source}:
            path_exists = inducing_path(G, source, dest, L, S)[0]
            if path_exists:
                return False

    return True


def dag_to_mag(G, L: Optional[Set] = None, S: Optional[Set] = None):
    """Converts a DAG to a valid MAG.

    The algorithm is defined in :footcite:`Zhang2008` on page 1877.

    Parameters
    ----------
    G : Graph
        The graph.
    L : Set
        Nodes that are ignored on the path. Defaults to an empty set.
    S : Set
        Nodes that are always conditioned on. Defaults to an empty set.

    Returns
    -------
    mag : Graph
        The MAG.
    """
    if L is None:
        L = set()

    if S is None:
        S = set()

    # for each pair of nodes find if they have an inducing path between them.
    # only then will they be adjacent in the MAG.
    nodes = list(G.nodes)
    adjacent_pairs = []
    seen = set()  # unordered pairs already known to be adjacent

    for source in nodes:
        for dest in nodes:
            if source == dest:
                continue
            pair = frozenset((source, dest))
            if pair in seen:
                continue
            # Both orderings of a pair are tried until one reports a path.
            path_exists = inducing_path(G, source, dest, L, S)[0]
            if path_exists:
                seen.add(pair)
                adjacent_pairs.append((source, dest))

    # find the ancestors of B U S (ansB) and A U S (ansA) for each pair of adjacent nodes
    selection = set(S)
    mag = ADMG()

    # NOTE(review): nodes only enter ``mag`` through ``add_edge`` here, so a
    # node without any inducing path to another node does not appear in the
    # result -- confirm whether isolated observed nodes should be added too.
    for A, B in adjacent_pairs:
        # Wrap the node in a set: ``S.union(A)`` would iterate over the node
        # itself, failing for non-iterable nodes and splitting string nodes
        # into characters.
        ansA: Set = set()
        for node in selection | {A}:
            ansA.update(_directed_sub_graph_ancestors(G, node))

        ansB: Set = set()
        for node in selection | {B}:
            ansB.update(_directed_sub_graph_ancestors(G, node))

        a_in_ans_b = A in ansB
        b_in_ans_a = B in ansA

        if a_in_ans_b and not b_in_ans_a:
            # if A is in ansB and B is not in ansA, A -> B
            mag.add_edge(A, B, mag.directed_edge_name)
        elif b_in_ans_a and not a_in_ans_b:
            # if B is in ansA and A is not in ansB, A <- B
            mag.add_edge(B, A, mag.directed_edge_name)
        elif not a_in_ans_b and not b_in_ans_a:
            # if A is not in ansB and B is not in ansA, A <-> B
            mag.add_edge(B, A, mag.bidirected_edge_name)
        else:
            # if A is in ansB and B is in ansA, A - B
            mag.add_edge(B, A, mag.undirected_edge_name)

    return mag


def is_maximal(G, L: Optional[Set] = None, S: Optional[Set] = None):
    """Checks to see if the graph is maximal.

    The graph is reported as maximal when ``inducing_path`` finds no inducing
    path between any pair of distinct, non-adjacent nodes.

    Parameters
    ----------
    G : Graph
        The graph.
    L : Set
        Nodes that are ignored on the path; forwarded to ``inducing_path``.
        Defaults to an empty set.
    S : Set
        Nodes that are always conditioned on; forwarded to ``inducing_path``.
        Defaults to an empty set.

    Returns
    -------
    is_maximal : bool
        A boolean indicating whether the provided graph is maximal or not.
    """
    if L is None:
        L = set()

    if S is None:
        S = set()

    all_nodes = set(G.nodes)
    checked = set()  # unordered pairs that have already been tested

    for source in all_nodes:
        # Set difference (rather than ``remove``) keeps this safe even when
        # ``source`` is listed among its own neighbours, e.g. for a self-loop.
        non_adjacent = all_nodes - set(G.neighbors(source)) - {source}
        for dest in non_adjacent:
            current_pair = frozenset((source, dest))
            if current_pair in checked:
                continue
            checked.add(current_pair)
            path_exists = inducing_path(G, source, dest, L, S)[0]
            if path_exists:
                return False

    return True
Loading

0 comments on commit 75c261c

Please sign in to comment.