Skip to content

Commit

Permalink
Fix style
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Aug 8, 2023
1 parent ec5af70 commit c8de7a3
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 23 deletions.
14 changes: 9 additions & 5 deletions examples/simulations/plot_discrete_causal_bayesian_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
# Import the required libraries
# -----------------------------
import networkx as nx
from pywhy_graphs.functional.discrete import make_random_discrete_graph
from pgmpy.factors.discrete.CPD import TabularCPD

from pywhy_graphs.functional import sample_from_graph
from pywhy_graphs.functional.discrete import make_random_discrete_graph
from pywhy_graphs.viz import draw
from pgmpy.factors.discrete.CPD import TabularCPD


# define a helper function to print the full CPD
def print_full(cpd):
Expand All @@ -37,6 +39,7 @@ def print_full(cpd):
print(cpd)
TabularCPD._truncate_strtable = backup


# %%
# Construct the causal graph
# --------------------------
Expand Down Expand Up @@ -86,9 +89,10 @@ def print_full(cpd):
node_dict = G.nodes["C"]

# We see that each node is fully defined given a conditional probability table, stored as a node
# attribute under the keyword 'cpd'. For more information on the CPD object, see pgmpy's documentation
# on :class:`pgmpy.factors.discrete.CPD.TabularCPD`. Note this is in contrast with what node attributes
# are required in general for simulating data from a causal graph in pywhy-graphs.
# attribute under the keyword 'cpd'. For more information on the CPD object, see
# pgmpy's documentation on :class:`pgmpy.factors.discrete.CPD.TabularCPD`. Note this
# is in contrast with what node attributes are required in general for simulating data
# from a causal graph in pywhy-graphs.
print_full(node_dict["cpd"])

# %%
Expand Down
1 change: 0 additions & 1 deletion examples/simulations/plot_linear_gaussian_causal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,3 @@
:ref:`ex-discrete-cbn`. Consider reading the user-guide, :ref:`functional-causal-graphical-models`
to understand how an arbitrary functional relationships are encoded in a causal graph.
"""

3 changes: 1 addition & 2 deletions pywhy_graphs/classes/augmented.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,11 @@ def directed_edge_name(self) -> str:
def n_domains(self):
domains = set()
for node_dict in self.nodes(data=True):
domain_ids = node_dict.get('domain_ids', None)
domain_ids = node_dict.get("domain_ids", None)
if domain_ids is not None:
domains.add(domain for domain in domain_ids)

return len(domains)


def _verify_augmentednode_dict(self):
# verify validity of F nodes
Expand Down
2 changes: 1 addition & 1 deletion pywhy_graphs/functional/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def sample_from_graph(
else:
directed_G = G

print('inside: ', directed_G.nodes(data=True))
print("inside: ", directed_G.nodes(data=True))
# check input
_check_input_graph(directed_G)

Expand Down
2 changes: 1 addition & 1 deletion pywhy_graphs/functional/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def apply_linear_soft_intervention(
of the target nodes. That is, the soft intervention, perturbs the
exogenous noise of the target nodes.
"""
if not G.graph.get("functional", 'linear_gaussian'):
if not G.graph.get("functional", "linear_gaussian"):
raise ValueError("The input graph must be a linear Gaussian graph.")
if not all(target in G.nodes for target in targets):
raise ValueError(f"All targets {targets} must be in the graph: {G.nodes}.")
Expand Down
12 changes: 7 additions & 5 deletions pywhy_graphs/functional/multidomain.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,18 @@ def apply_domain_shift(G, node, domain_ids, exogenous_distribution=None, random_
rng = np.random.default_rng(random_state)
exogenous_distribution = lambda: rng.standard_normal()

# determine which S-node the domain IDs corresond to
# determine which S-node the domain IDs correspond to
snode = G.domain_ids_to_snode[domain_ids]

if not G.has_edge(snode, node):
raise RuntimeError(f'Node {node} does not have an S-node {snode} pointing to it for domain'
f'pairs {domain_ids}.')

raise RuntimeError(
f"Node {node} does not have an S-node {snode} pointing to it for domain"
f"pairs {domain_ids}."
)

# now add a new exogenous distribution for the node
domain_id = domain_ids[1]
G.nodes[node]['domain'][domain_id]["exogenous_distribution"] = lambda: exogenous_distribution()
G.nodes[node]["domain"][domain_id]["exogenous_distribution"] = lambda: exogenous_distribution()
return G


Expand Down
21 changes: 13 additions & 8 deletions pywhy_graphs/viz/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def draw(
direction: Optional[str] = None,
pos: Optional[dict] = None,
name: Optional[str] = None,
node_order: Optional[List] = None,
shape: str = "square",
**attrs,
):
Expand Down Expand Up @@ -123,6 +124,9 @@ def draw(
if direction == "LR":
dot.graph_attr["rankdir"] = direction

if node_order is None:
node_order = G.nodes

circle_edges = None
directed_edges = None
undirected_edges = None
Expand All @@ -136,10 +140,17 @@ def draw(
if hasattr(G, "undirected_edges"):
undirected_edges = G.undirected_edges
elif isinstance(G, nx.Graph) and not G.is_directed():
undirected_edges = G.edges()
undirected_edges = G.edges()
if hasattr(G, "bidirected_edges"):
bidirected_edges = G.bidirected_edges

for v in node_order:
child = str(v)
if pos and pos.get(v) is not None:
dot.node(child, shape=shape, height=".5", width=".5", pos=f"{pos[v][0]},{pos[v][1]}!")
else:
dot.node(child, shape=shape, height=".5", width=".5")

# draw PAG edges and keep track of the circular endpoints found
dot, found_circle_sibs = _draw_circle_edges(
dot,
Expand All @@ -157,14 +168,8 @@ def draw(

# only need to draw directed edges now, but directed_G can be a nx.Graph
if hasattr(directed_G, "predecessors"):
for v in G.nodes:
for v in node_order:
child = str(v)
if pos and pos.get(v) is not None:
dot.node(
child, shape=shape, height=".5", width=".5", pos=f"{pos[v][0]},{pos[v][1]}!"
)
else:
dot.node(child, shape=shape, height=".5", width=".5")

for parent in directed_G.predecessors(v):
if parent == v or not directed_G.has_edge(parent, v):
Expand Down
13 changes: 13 additions & 0 deletions pywhy_graphs/viz/tests/test_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,19 @@ def test_draw_pos_contains_more_nodes():
assert "pos=" not in re.search(r"\tz \[(.*)\]", dot_body_text).groups()[0]


def test_draw_does_not_show_undirected():
graph = nx.DiGraph()

graph.add_edge("x", "y")
graph.add_edge("y", "z")

dot = draw(graph)
dot_body_text = "".join(dot.body)

# there should not be a drawn undirected edges
assert "dir=none" not in dot_body_text


def test_draw_pos_with_pag():
"""
Ensure the Graphviz pos="x,y!" attribute is generated by the draw function
Expand Down

0 comments on commit c8de7a3

Please sign in to comment.