diff --git a/pywhy_graphs/algorithms/tests/test_pag.py b/pywhy_graphs/algorithms/tests/test_pag.py index 5642c5e8..d7815bcf 100644 --- a/pywhy_graphs/algorithms/tests/test_pag.py +++ b/pywhy_graphs/algorithms/tests/test_pag.py @@ -3,7 +3,7 @@ import pytest import pywhy_graphs -from pywhy_graphs import ADMG, PAG +from pywhy_graphs import PAG from pywhy_graphs.algorithms import ( discriminating_path, is_definite_noncollider, @@ -664,13 +664,12 @@ def test_pag_to_mag(): out_mag = pywhy_graphs.pag_to_mag(pag) - out_edges = list(out_mag.edges()["directed"]) assert ( - ((("A", "B") in out_edges) or (("B", "A") in out_edges)) - and ((("A", "C") in out_edges) or (("C", "A") in out_edges)) - and (("A", "D") in out_edges) - and (("B", "D") in out_edges) - and (("C", "D") in out_edges) + ((out_mag.has_edge("A", "B")) or (out_mag.has_edge("B", "A"))) + and ((out_mag.has_edge("A", "C")) or (out_mag.has_edge("C", "A"))) + and (out_mag.has_edge("A", "D")) + and (out_mag.has_edge("B", "D")) + and (out_mag.has_edge("C", "D")) ) pag = PAG() @@ -683,14 +682,8 @@ def test_pag_to_mag(): out_mag = pywhy_graphs.pag_to_mag(pag) - mag = ADMG() - mag.add_edge("B", "A", mag.directed_edge_name) - mag.add_edge("D", "A", mag.directed_edge_name) - mag.add_edge("D", "B", mag.directed_edge_name) - - out_edges = list(out_mag.edges()["directed"]) assert ( - (("B", "A") in out_edges) - and (("D", "A") in out_edges) - and ((("D", "B") in out_edges) or (("B", "D") in out_edges)) + out_mag.has_edge("B", "A") + and out_mag.has_edge("D", "A") + and (out_mag.has_edge("D", "B") or out_mag.has_edge("B", "D")) )