Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add the ability to check if an edge is "visible" #119

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion pywhy_graphs/algorithms/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"dag_to_mag",
"is_maximal",
"all_vstructures",
"check_visibility"
]


Expand Down Expand Up @@ -826,7 +827,6 @@ def is_maximal(G, L: Optional[Set] = None, S: Optional[Set] = None):
continue
return True


def all_vstructures(G: nx.DiGraph, as_edges: bool = False):
"""Generate all v-structures in the graph.

Expand Down Expand Up @@ -855,3 +855,58 @@ def all_vstructures(G: nx.DiGraph, as_edges: bool = False):
else:
vstructs.add((p1, node, p2)) # type: ignore
return vstructs

def get_all_collider_paths(G : PAG, X, Y):

out = []

# find all the possible paths from X to Y with only bi-directed edges

bidirected_edge_graph = G.sub_bidirected_graph

X_descendants = set(G.sub_directed_graph.neigbors(X))

candidate_collider_path_nodes = set(bidirected_edge_graph.nodes).intersection(X_descendants)

if candidate_collider_path_nodes is None:
return out

for elem in candidate_collider_path_nodes:
out.extend(nx.all_simple_paths(G, elem, Y))

# for path in out:
# path.insert(0,X)

return out

def check_visibility(G: PAG, X: str, Y: str):

X_neighbors = set(G.neighbors(X))
Y_neighbors = set(G.neighbors(Y))

only_x_neighbors = X_neighbors - Y_neighbors


for elem in only_x_neighbors:
if G.has_edge(elem, X, G.directed_edge_name):
return True

all_nodes = set(G.nodes)

all_nodes.remove(X)


candidates = all_nodes - Y_neighbors

for elem in candidates:
collider_paths = get_all_collider_paths(G,elem,X)
for path in collider_paths:
for node in path:
if node in G.neighbors(Y):
continue
else:
return True

return False


33 changes: 31 additions & 2 deletions pywhy_graphs/algorithms/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import pytest

import pywhy_graphs
from pywhy_graphs import ADMG
from pywhy_graphs.algorithms import all_vstructures
from pywhy_graphs import ADMG, PAG
from pywhy_graphs.algorithms import all_vstructures, check_visibility


def test_convert_to_latent_confounder_errors():
Expand Down Expand Up @@ -496,3 +496,32 @@ def test_all_vstructures():
# Assert that the returned values are as expected
assert len(v_structs_edges) == 0
assert len(v_structs_tuples) == 0



def test_check_visibility():

# H <-> K <-> Z <-> X <- Y

pag = PAG()
pag.add_edge("Y", "X", pag.directed_edge_name)
pag.add_edge("Z", "X", pag.bidirected_edge_name)
pag.add_edge("Z", "K", pag.bidirected_edge_name)
pag.add_edge("K", "H", pag.bidirected_edge_name)

assert True == check_visibility(pag, "X", "Y")

pag = PAG()
pag.add_edge("Y", "X", pag.directed_edge_name)
pag.add_edge("Z", "X", pag.bidirected_edge_name)

assert True == check_visibility(pag, "X", "Y")

pag = PAG()
pag.add_edge("Y", "X", pag.directed_edge_name)
pag.add_edge("Z", "Y", pag.bidirected_edge_name)
pag.add_edge("Z", "K", pag.bidirected_edge_name)
pag.add_edge("K", "H", pag.bidirected_edge_name)

assert False == check_visibility(pag, "X", "Y")