Skip to content

Commit

Permalink
Creates VisualizationNodeModifiers enum
Browse files Browse the repository at this point in the history
This is to house the possible visual modifiers for a node.
This is an enum to (1) keep them grouped, (2) they were
being used in a boolean fashion, so the presence of the enum
is enough.
  • Loading branch information
skrawcz authored and elijahbenizzy committed May 25, 2023
1 parent 79c3756 commit 0122dce
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 23 deletions.
14 changes: 7 additions & 7 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,11 @@ def visualize_execution(
_final_vars = self._create_final_vars(final_vars)
nodes, user_nodes = self.graph.get_upstream_nodes(_final_vars, inputs)
self.validate_inputs(user_nodes, inputs, nodes)
node_modifiers = {fv: {"is_output": True} for fv in _final_vars}
node_modifiers = {fv: {graph.VisualizationNodeModifiers.IS_OUTPUT} for fv in _final_vars}
for user_node in user_nodes:
if user_node.name not in node_modifiers:
node_modifiers[user_node.name] = {}
node_modifiers[user_node.name]["is_user_input"] = True
node_modifiers[user_node.name] = set()
node_modifiers[user_node.name].add(graph.VisualizationNodeModifiers.IS_USER_INPUT)
try:
return self.graph.display(
nodes.union(user_nodes),
Expand Down Expand Up @@ -556,7 +556,7 @@ def display_upstream_of(
upstream_nodes, user_nodes = self.graph.get_upstream_nodes(list(node_names))
node_modifiers = {}
for n in user_nodes:
node_modifiers[n.name] = {"is_user_input": True}
node_modifiers[n.name] = {graph.VisualizationNodeModifiers.IS_USER_INPUT}
try:
return self.graph.display(
upstream_nodes,
Expand Down Expand Up @@ -663,7 +663,7 @@ def visualize_path_between(
node_modifiers = {}
for n in self.graph.get_nodes():
if n.user_defined:
node_modifiers[n.name] = {"is_user_input": True}
node_modifiers[n.name] = {graph.VisualizationNodeModifiers.IS_USER_INPUT}

# create nodes that constitute the path
nodes_for_path = self._get_nodes_between(upstream_node_name, downstream_node_name)
Expand All @@ -674,8 +674,8 @@ def visualize_path_between(
# add is path for node_modifier's dict
for n in nodes_for_path:
if n.name not in node_modifiers:
node_modifiers[n.name] = {}
node_modifiers[n.name]["is_path"] = True
node_modifiers[n.name] = set()
node_modifiers[n.name].add(graph.VisualizationNodeModifiers.IS_PATH)
try:
return self.graph.display(
nodes_for_path,
Expand Down
27 changes: 18 additions & 9 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Note: one should largely consider the code in this module to be "private".
"""
import logging
from enum import Enum
from types import ModuleType
from typing import Any, Callable, Collection, Dict, List, Optional, Set, Tuple, Type

Expand All @@ -17,6 +18,14 @@
logger = logging.getLogger(__name__)


class VisualizationNodeModifiers(Enum):
"""Enum of all possible node modifiers for visualization."""

IS_OUTPUT = 1
IS_PATH = 2
IS_USER_INPUT = 3


def add_dependency(
func_node: node.Node,
func_name: str,
Expand Down Expand Up @@ -97,7 +106,7 @@ def create_graphviz_graph(
nodes: Set[node.Node],
comment: str,
graphviz_kwargs: dict,
node_modifiers: Dict[str, dict],
node_modifiers: Dict[str, Set[VisualizationNodeModifiers]],
strictly_display_only_nodes_passed_in: bool,
) -> "graphviz.Digraph": # noqa: F821
"""Helper function to create a graphviz graph.
Expand All @@ -119,12 +128,13 @@ def create_graphviz_graph(
other_args = {}
# checks if the node has any modifiers
if n.name in node_modifiers:
modifiers = node_modifiers[n.name]
# if node is an output, then modify the node to be a rectangle
if node_modifiers[n.name].get("is_output"):
if VisualizationNodeModifiers.IS_OUTPUT in modifiers:
other_args["shape"] = "rectangle"
if node_modifiers[n.name].get("is_path"):
if VisualizationNodeModifiers.IS_PATH in modifiers:
other_args["color"] = "red"
if node_modifiers[n.name].get("is_user_input"):
if VisualizationNodeModifiers.IS_USER_INPUT in modifiers:
other_args["style"] = "dashed"
label = f"Input: {n.name}"
digraph.node(n.name, label=label, **other_args)
Expand All @@ -136,10 +146,9 @@ def create_graphviz_graph(
if (
d not in nodes
and d.name in node_modifiers
and node_modifiers[d.name].get("is_user_input")
and VisualizationNodeModifiers.IS_USER_INPUT in node_modifiers[d.name]
):
digraph.node(d.name, label=f"Input: {d.name}", style="dashed")
# print(f"Adding edge from {d.name} to {n.name}")
digraph.edge(d.name, n.name)
return digraph

Expand Down Expand Up @@ -223,7 +232,7 @@ def display_all(
node_modifiers = {}
for n in self.nodes.values():
if n.user_defined:
node_modifiers[n.name] = {"is_user_input": True}
node_modifiers[n.name] = {VisualizationNodeModifiers.IS_USER_INPUT}
all_nodes.add(n)
if render_kwargs is None:
render_kwargs = {}
Expand Down Expand Up @@ -272,7 +281,7 @@ def display(
output_file_path: Optional[str] = "test-output/graph.gv",
render_kwargs: dict = None,
graphviz_kwargs: dict = None,
node_modifiers: Dict[str, dict] = None,
node_modifiers: Dict[str, Set[VisualizationNodeModifiers]] = None,
strictly_display_only_passed_in_nodes: bool = False,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Function to display the graph represented by the passed in nodes.
Expand All @@ -283,7 +292,7 @@ def display(
:param graphviz_kwargs: kwargs to be passed to the graphviz graph object to configure it.
e.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image.
:param node_modifiers: a dictionary of node names to a dictionary of attributes to modify.
e.g. {'node_name': {'is_user_input': True}} will set the node named 'node_name' to be a user input.
e.g. {'node_name': {NodeModifiers.IS_USER_INPUT}} will set the node named 'node_name' to be a user input.
:param strictly_display_only_passed_in_nodes: if True, only display the nodes passed in. Else defaults to
displaying also what nodes a node depends on (i.e. all nodes that feed into it).
:return: the graphviz graph object if it was created. None if not.
Expand Down
14 changes: 7 additions & 7 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,11 +497,11 @@ def test_function_graph_has_cycles_false():
def test_function_graph_display():
"""Tests that display saves a file"""
fg = graph.FunctionGraph(tests.resources.dummy_functions, config={"b": 1, "c": 2})
node_modifiers = {"B": {"is_output": True}}
node_modifiers = {"B": {graph.VisualizationNodeModifiers.IS_OUTPUT}}
all_nodes = set()
for n in fg.get_nodes():
if n.user_defined:
node_modifiers[n.name] = {"is_user_input": True}
node_modifiers[n.name] = {graph.VisualizationNodeModifiers.IS_USER_INPUT}
all_nodes.add(n)
# hack of a test -- but it works... sort the lines and match them up.
# why? because for some reason given the same graph, the output file isn't deterministic.
Expand Down Expand Up @@ -533,10 +533,10 @@ def test_function_graph_display_without_saving():
"""Tests that display works when None is passed in for path"""
fg = graph.FunctionGraph(tests.resources.dummy_functions, config={"b": 1, "c": 2})
all_nodes = set()
node_modifiers = {"B": {"is_output": True}}
node_modifiers = {"B": {graph.VisualizationNodeModifiers.IS_OUTPUT}}
for n in fg.get_nodes():
if n.user_defined:
node_modifiers[n.name] = {"is_user_input": True}
node_modifiers[n.name] = {graph.VisualizationNodeModifiers.IS_USER_INPUT}
all_nodes.add(n)
digraph = fg.display(all_nodes, None, node_modifiers=node_modifiers)
assert digraph is not None
Expand All @@ -551,9 +551,9 @@ def test_create_graphviz_graph():
nodes, user_nodes = fg.get_upstream_nodes(["A", "B", "C"])
nodez = nodes.union(user_nodes)
node_modifiers = {
"b": {"is_user_input": True},
"c": {"is_user_input": True},
"B": {"is_output": True},
"b": {graph.VisualizationNodeModifiers.IS_USER_INPUT},
"c": {graph.VisualizationNodeModifiers.IS_USER_INPUT},
"B": {graph.VisualizationNodeModifiers.IS_OUTPUT},
}
# hack of a test -- but it works... sort the lines and match them up.
# why? because for some reason given the same graph, the output file isn't deterministic.
Expand Down

0 comments on commit 0122dce

Please sign in to comment.