From c62db2b2e30fa0d638f783e8f736ae742fdf90f5 Mon Sep 17 00:00:00 2001 From: zilto Date: Wed, 1 Nov 2023 19:55:31 -0400 Subject: [PATCH 01/10] improved DAG viz from graph.create_graphviz_graph() --- hamilton/graph.py | 265 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 216 insertions(+), 49 deletions(-) diff --git a/hamilton/graph.py b/hamilton/graph.py index bb84ecbfe..8d6fd730d 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -20,8 +20,6 @@ logger = logging.getLogger(__name__) -PATH_COLOR = "red" - class VisualizationNodeModifiers(Enum): """Enum of all possible node modifiers for visualization.""" @@ -140,6 +138,9 @@ def create_graphviz_graph( graphviz_kwargs: dict, node_modifiers: Dict[str, Set[VisualizationNodeModifiers]], strictly_display_only_nodes_passed_in: bool, + left_to_right: bool = True, + hide_inputs: bool = False, + deduplicate_inputs: bool = False, ) -> "graphviz.Digraph": # noqa: F821 """Helper function to create a graphviz graph. @@ -150,67 +151,233 @@ def create_graphviz_graph( :param node_modifiers: A dictionary of node names to dictionaries of node attributes to modify. :param strictly_display_only_nodes_passed_in: If True, only display the nodes passed in. Else defaults to displaying also what nodes a node depends on (i.e. all nodes that feed into it). + :param left_to_right: If True, the DAG is displayed from left to right. Else, top to bottom. This value can + be overriden by the value of `graphviz_kwargs['graph_attr']['rankdir']` :return: a graphviz.Digraph; use this to render/save a graph representation. """ + PATH_COLOR = "#7A3B69" + import graphviz - digraph = graphviz.Digraph(comment=comment, **graphviz_kwargs) + def _get_node_label(n: node.Node, name=None, type_=None) -> str: + name = n.name if name is None else name + type_ = n.type.__name__ if type_ is None else type_ + return f"<{name}

{type_}>" + + def _get_input_label(input_nodes: frozenset[node.Node]) -> str: + rows = [f"{dep.name}{dep.type.__name__}" for dep in input_nodes] + return f"<{''.join(rows)}
>" + + def _get_node_type(n: node.Node) -> str: + if ( + n._node_source == node.NodeType.EXTERNAL + and n._originating_functions is None + and n._depended_on_by + ): + return "input" + elif ( + n._node_source == node.NodeType.EXTERNAL + and n._originating_functions is None + and not n._depended_on_by + ): + return "config" + else: + return "function" + + def _get_node_style(node_type: str): + fontname = "Helvetica" + + if node_type == "config": + node_style = dict( + shape="note", + style="", + fontname=fontname, + ) + elif node_type == "input": + node_style = dict( + shape="rectangle", + margin="0.15", + style="dashed", + fontname=fontname, + ) + elif node_type == "materializer": + node_style = dict( + shape="cylinder", + margin="0.15,0.1", + fontname=fontname, + ) + else: # this is a function or else + node_style = dict( + shape="rectangle", + margin="0.15", + style="rounded,filled", + fillcolor="#b4d8e4", + fontname=fontname, + ) + + return node_style + + def _get_function_modifier_style(modifier: str): + if modifier == "output": + modifier_style = dict(fillcolor="#FFC857") + elif modifier == "collect": + modifier_style = dict(peripheries="2", color="#EA5556") + elif modifier == "expand": + modifier_style = dict(peripheries="2", color="#56E39F") + elif modifier == "override": + modifier_style = dict(style="filled,diagonals") + elif modifier == "materializer": + modifier_style = dict(shape="cylinder") + else: + modifier_style = dict() + + return modifier_style + + def _get_edge_style(node_type: str): + if node_type == "expand": + edge_style = dict( + dir="both", + arrowhead="crow", + arrowtail="none", + ) + elif node_type == "collect": + edge_style = dict(dir="both", arrowtail="crow") + else: + edge_style = dict() + + return edge_style + + def _get_legend(node_types): + legend_subgraph = graphviz.Digraph( + name="cluster_legend", # needs to start with `cluster_` for graphviz layout + graph_attr=dict( + label="Legend", + rank="same", + fontname="helvetica", + ), + ) + + sorted_types = [ + "config", + "input", + "materializer", + "function", + "output", + "override", + "expand", + "collect", + ] + + for node_type in sorted(node_types, key=lambda t: sorted_types.index(t)): + node_style = _get_node_style(node_type) + modifier_style = _get_function_modifier_style(node_type) + node_style.update(**modifier_style) + legend_subgraph.node(name=node_type, **node_style) + + return legend_subgraph + + rankdir = "LR" if left_to_right is False else "TB" + # handle default values in nested dict + digraph_attr = dict( + comment=comment, + graph_attr=dict( + rankdir=rankdir, + ranksep="0.4", + compound="true", + concentrate="true", + ), + ) + digraph_attr.update(**graphviz_kwargs) + digraph = graphviz.Digraph(**digraph_attr) + + # create nodes + seen_node_types = set() for n in nodes: - label = n.name - other_args = {} - # checks if the node has any modifiers - if n.name in node_modifiers: - modifiers = node_modifiers[n.name] - # if node is an output, then modify the node to be a rectangle + label = _get_node_label(n) + node_type = _get_node_type(n) + seen_node_types.add(node_type) + if node_type == "input": + continue + + node_style = _get_node_style(node_type) + + # prefer having the conditions explicit for now since they rely on + # VisualizationNodeModifiers and node.Node.node_role + if modifiers := node_modifiers.get(n.name): if VisualizationNodeModifiers.IS_OUTPUT in modifiers: - other_args["shape"] = "rectangle" + modifier_style = _get_function_modifier_style("output") + node_style.update(**modifier_style) + seen_node_types.add("output") + + if VisualizationNodeModifiers.IS_OVERRIDE in modifiers: + modifier_style = _get_function_modifier_style("override") + node_style.update(**modifier_style) + seen_node_types.add("override") + if VisualizationNodeModifiers.IS_PATH in modifiers: - other_args["color"] = PATH_COLOR + node_style["color"] = PATH_COLOR - if VisualizationNodeModifiers.IS_USER_INPUT in modifiers: - other_args["style"] = "dashed" - label = f"Input: {n.name}" + if n.node_role == node.NodeType.EXPAND: + modifier_style = _get_function_modifier_style("expand") + node_style.update(**modifier_style) + seen_node_types.add("expand") - if VisualizationNodeModifiers.IS_OVERRIDE in modifiers: - other_args["style"] = "dashed" - label = f"Override: {n.name}" - is_expand_node = n.node_role == node.NodeType.EXPAND - is_collect_node = n.node_role == node.NodeType.COLLECT + if n.node_role == node.NodeType.COLLECT: + modifier_style = _get_function_modifier_style("collect") + node_style.update(**modifier_style) + seen_node_types.add("collect") + + if n._tags.get("hamilton.data_saver"): + materializer_type = n._tags["hamilton.data_saver.classname"] + label = _get_node_label(n, type_=materializer_type) + modifier_style = _get_function_modifier_style("materializer") + node_style.update(**modifier_style) + seen_node_types.add("materializer") - if is_collect_node or is_expand_node: - other_args["peripheries"] = "2" - digraph.node(n.name, label=label, **other_args) + digraph.node(n.name, label=label, **node_style) + + # create edges + input_sets = dict() + for n in nodes: + input_nodes = set() - for n in list(nodes): for d in n.dependencies: if strictly_display_only_nodes_passed_in and d not in nodes: continue - if ( - d not in nodes - and d.name in node_modifiers - and VisualizationNodeModifiers.IS_USER_INPUT in node_modifiers[d.name] - ): - digraph.node(d.name, label=f"Input: {d.name}", style="dashed") - from_modifiers = node_modifiers.get(d.name, set()) - to_modifiers = node_modifiers.get(n.name, set()) - other_args = {} - if ( - VisualizationNodeModifiers.IS_PATH in from_modifiers - and VisualizationNodeModifiers.IS_PATH in to_modifiers - ): - other_args["color"] = PATH_COLOR - is_collect_edge = n.node_role == node.NodeType.COLLECT - is_expand_edge = d.node_role == node.NodeType.EXPAND - - if is_collect_edge: - other_args["dir"] = "both" - other_args["arrowtail"] = "crow" - - if is_expand_edge: - other_args["dir"] = "both" - other_args["arrowhead"] = "crow" - other_args["arrowtail"] = "none" - digraph.edge(d.name, n.name, **other_args) + + node_type = _get_node_type(d) + if node_type == "input": + input_nodes.add(d) + continue + + edge_style = _get_edge_style(d) + digraph.edge(d.name, n.name, **edge_style) + + # skip input node creation + if hide_inputs: + continue + + if len(input_nodes) > 0: + input_node_name = f"{n.name}_inputs" + + # following block is for input node deduplication + input_nodes = frozenset(input_nodes) + if existing_input_name := input_sets.get(input_nodes): + digraph.edge(existing_input_name, n.name) + continue + + # allow duplicate input nodes by never storing keys + if deduplicate_inputs: + input_sets[input_nodes] = input_node_name + + # create input node + node_label = _get_input_label(input_nodes) + node_style = _get_node_style("input") + digraph.node(name=input_node_name, label=node_label, **node_style) + # create edge for input node + digraph.edge(input_node_name, n.name) + + digraph.subgraph(_get_legend(seen_node_types)) return digraph From d7e368dda88f4d7944060e973cf42275afd654ee Mon Sep 17 00:00:00 2001 From: zilto Date: Thu, 2 Nov 2023 11:04:03 -0400 Subject: [PATCH 02/10] added docstrings, comments, and edge logic --- hamilton/graph.py | 112 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 83 insertions(+), 29 deletions(-) diff --git a/hamilton/graph.py b/hamilton/graph.py index 8d6fd730d..6943d2c10 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -70,7 +70,9 @@ def add_dependency( def update_dependencies( - nodes: Dict[str, node.Node], adapter: base.HamiltonGraphAdapter, reset_dependencies: bool = True + nodes: Dict[str, node.Node], + adapter: base.HamiltonGraphAdapter, + reset_dependencies: bool = True, ): """Adds dependencies to a dictionary of nodes. If in_place is False, it will deepcopy the dict + nodes and return that. Otherwise it will @@ -138,7 +140,7 @@ def create_graphviz_graph( graphviz_kwargs: dict, node_modifiers: Dict[str, Set[VisualizationNodeModifiers]], strictly_display_only_nodes_passed_in: bool, - left_to_right: bool = True, + orient: str = "LR", hide_inputs: bool = False, deduplicate_inputs: bool = False, ) -> "graphviz.Digraph": # noqa: F821 @@ -151,24 +153,49 @@ def create_graphviz_graph( :param node_modifiers: A dictionary of node names to dictionaries of node attributes to modify. :param strictly_display_only_nodes_passed_in: If True, only display the nodes passed in. Else defaults to displaying also what nodes a node depends on (i.e. all nodes that feed into it). - :param left_to_right: If True, the DAG is displayed from left to right. Else, top to bottom. This value can - be overriden by the value of `graphviz_kwargs['graph_attr']['rankdir']` + :param orient: `LR` stands for "left to right". Accepted values are TB, LR, BT, RL. + `orient` will be overwridden by the value of `graphviz_kwargs['graph_attr']['rankdir']` + see (https://graphviz.org/docs/attr-types/rankdir/) + :param hide_inputs: If True, no input nodes are displayed. + :param deduplicate_inputs: If True, remove duplicate input nodes. + Can improve readability depending on the specifics of the DAG. :return: a graphviz.Digraph; use this to render/save a graph representation. """ PATH_COLOR = "#7A3B69" import graphviz - def _get_node_label(n: node.Node, name=None, type_=None) -> str: + def _get_node_label( + n: node.Node, + name: Optional[str] = None, + type_: Optional[str] = None, + ) -> str: + """Get a graphviz HTML-like node label. It uses the DAG node + name and type but values can be overridden. Overriding is currently + used for materializers since `type_` is stored in n._tags. + + ref: https://graphviz.org/doc/info/shapes.html#html + """ name = n.name if name is None else name type_ = n.type.__name__ if type_ is None else type_ return f"<{name}

{type_}>" def _get_input_label(input_nodes: frozenset[node.Node]) -> str: + """Get a graphviz HTML-like node label formatted as a table. + Each row is a different input node with one column containing + the name and the other the type. + ref: https://graphviz.org/doc/info/shapes.html#html + """ rows = [f"{dep.name}{dep.type.__name__}" for dep in input_nodes] return f"<{''.join(rows)}
>" def _get_node_type(n: node.Node) -> str: + """Get the node type of a DAG node. + + Input: is external, doesn't originate from a function, functions depend on it + Config: is external, doesn't originate from a function, no function depedends on it + Function: others + """ if ( n._node_source == node.NodeType.EXTERNAL and n._originating_functions is None @@ -184,7 +211,10 @@ def _get_node_type(n: node.Node) -> str: else: return "function" - def _get_node_style(node_type: str): + def _get_node_style(node_type: str) -> dict[str, str]: + """Get the style of a node type. + Graphviz needs values to be strings. + """ fontname = "Helvetica" if node_type == "config": @@ -217,7 +247,11 @@ def _get_node_style(node_type: str): return node_style - def _get_function_modifier_style(modifier: str): + def _get_function_modifier_style(modifier: str) -> dict[str, str]: + """Get the style of a modifier. The dictionary returned + is used to overwrite values of the base node style. + Graphviz needs values to be strings. + """ if modifier == "output": modifier_style = dict(fillcolor="#FFC857") elif modifier == "collect": @@ -233,26 +267,35 @@ def _get_function_modifier_style(modifier: str): return modifier_style - def _get_edge_style(node_type: str): - if node_type == "expand": - edge_style = dict( + def _get_edge_style(from_type: str, to_type: str) -> dict: + """ + + Graphviz needs values to be strings. + """ + edge_style = dict() + + if from_type == "expand": + print(from_type, to_type) + edge_style.update( dir="both", arrowhead="crow", arrowtail="none", ) - elif node_type == "collect": - edge_style = dict(dir="both", arrowtail="crow") - else: - edge_style = dict() + + if to_type == "collect": + edge_style.update(dir="both", arrowtail="crow") return edge_style def _get_legend(node_types): + """Create a visualization legend as a graphviz subgraph. The legend includes the + node types and modifiers presente in the visualization. + """ legend_subgraph = graphviz.Digraph( - name="cluster_legend", # needs to start with `cluster_` for graphviz layout + name="cluster__legend", # needs to start with `cluster` for graphviz layout graph_attr=dict( label="Legend", - rank="same", + rank="same", # makes the legend perpendicular to the main DAG fontname="helvetica", ), ) @@ -260,9 +303,9 @@ def _get_legend(node_types): sorted_types = [ "config", "input", - "materializer", "function", "output", + "materializer", "override", "expand", "collect", @@ -276,12 +319,11 @@ def _get_legend(node_types): return legend_subgraph - rankdir = "LR" if left_to_right is False else "TB" # handle default values in nested dict digraph_attr = dict( comment=comment, graph_attr=dict( - rankdir=rankdir, + rankdir=orient, ranksep="0.4", compound="true", concentrate="true", @@ -302,8 +344,10 @@ def _get_legend(node_types): node_style = _get_node_style(node_type) # prefer having the conditions explicit for now since they rely on - # VisualizationNodeModifiers and node.Node.node_role - if modifiers := node_modifiers.get(n.name): + # heterogeneous VisualizationNodeModifiers and node.Node.node_role. + # Otherwise, it's difficult to manage seen nodes and the legend. + if node_modifiers.get(n.name): + modifiers = node_modifiers[n.name] if VisualizationNodeModifiers.IS_OUTPUT in modifiers: modifier_style = _get_function_modifier_style("output") node_style.update(**modifier_style) @@ -339,30 +383,36 @@ def _get_legend(node_types): # create edges input_sets = dict() for n in nodes: - input_nodes = set() + to_type = "collect" if n.node_role == node.NodeType.COLLECT else "" + input_nodes = set() for d in n.dependencies: if strictly_display_only_nodes_passed_in and d not in nodes: continue - node_type = _get_node_type(d) - if node_type == "input": + dependency_type = _get_node_type(d) + # input nodes and edges are gathered instead of drawn + # they are drawn later, see below + if dependency_type == "input": input_nodes.add(d) continue - edge_style = _get_edge_style(d) + from_type = "expand" if d.node_role == node.NodeType.EXPAND else "" + edge_style = _get_edge_style(from_type, to_type) digraph.edge(d.name, n.name, **edge_style) # skip input node creation if hide_inputs: continue + # draw input nodes if at least 1 exist if len(input_nodes) > 0: - input_node_name = f"{n.name}_inputs" + input_node_name = f"_{n.name}_inputs" # following block is for input node deduplication input_nodes = frozenset(input_nodes) - if existing_input_name := input_sets.get(input_nodes): + if input_sets.get(input_nodes): + existing_input_name = input_sets[input_nodes] digraph.edge(existing_input_name, n.name) continue @@ -433,7 +483,9 @@ def __init__( @staticmethod def from_modules( - *modules: ModuleType, config: Dict[str, Any], adapter: base.HamiltonGraphAdapter = None + *modules: ModuleType, + config: Dict[str, Any], + adapter: base.HamiltonGraphAdapter = None, ): """Initializes a function graph from the specified modules. Note that this was the old way we constructed FunctionGraph -- this is not a public-facing API, so we replaced it @@ -668,7 +720,9 @@ def nodes_between(self, start: str, end: str) -> Set[node.Node]: return set(([start_node] if start_node is not None else []) + between + [end_node]) def directional_dfs_traverse( - self, next_nodes_fn: Callable[[node.Node], Collection[node.Node]], starting_nodes: List[str] + self, + next_nodes_fn: Callable[[node.Node], Collection[node.Node]], + starting_nodes: List[str], ): """Traverses the DAG directionally using a DFS. From 400964b45884534d30f25006d34d3b70f2cfd5fd Mon Sep 17 00:00:00 2001 From: zilto Date: Thu, 2 Nov 2023 11:11:11 -0400 Subject: [PATCH 03/10] added parameter to show/hide legend --- hamilton/graph.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/hamilton/graph.py b/hamilton/graph.py index 6943d2c10..c8dd6122d 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -140,6 +140,7 @@ def create_graphviz_graph( graphviz_kwargs: dict, node_modifiers: Dict[str, Set[VisualizationNodeModifiers]], strictly_display_only_nodes_passed_in: bool, + show_legend: bool = True, orient: str = "LR", hide_inputs: bool = False, deduplicate_inputs: bool = False, @@ -153,6 +154,7 @@ def create_graphviz_graph( :param node_modifiers: A dictionary of node names to dictionaries of node attributes to modify. :param strictly_display_only_nodes_passed_in: If True, only display the nodes passed in. Else defaults to displaying also what nodes a node depends on (i.e. all nodes that feed into it). + :param show_legend: If True, add a legend to the visualization based on the DAG's nodes. :param orient: `LR` stands for "left to right". Accepted values are TB, LR, BT, RL. `orient` will be overwridden by the value of `graphviz_kwargs['graph_attr']['rankdir']` see (https://graphviz.org/docs/attr-types/rankdir/) @@ -427,7 +429,8 @@ def _get_legend(node_types): # create edge for input node digraph.edge(input_node_name, n.name) - digraph.subgraph(_get_legend(seen_node_types)) + if show_legend: + digraph.subgraph(_get_legend(seen_node_types)) return digraph From 0d755f01a6ef4ddc5bc418a12eadfe62a7c9cf4d Mon Sep 17 00:00:00 2001 From: zilto Date: Thu, 2 Nov 2023 12:35:57 -0400 Subject: [PATCH 04/10] added tests for show_legend, orient, and hide_inputs --- tests/test_graph.py | 52 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/test_graph.py b/tests/test_graph.py index 6bc420bdd..58ee6fa7c 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,4 +1,5 @@ import inspect +import pathlib import tempfile import uuid from itertools import permutations @@ -538,6 +539,57 @@ def test_function_graph_display(): assert dot_file is not None +@pytest.mark.parametrize("show_legend", [(True), (False)]) +def test_function_graph_display_legend(show_legend: bool, tmp_path: pathlib.Path): + dot_file_path = tmp_path / "dag.dot" + fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2}) + + fg.display( + set(fg.get_nodes()), + output_file_path=str(dot_file_path), + render_kwargs={"view": False}, + show_legend=show_legend, + ) + dot = dot_file_path.open("r").read() + + found_legend = "cluster__legend" in dot + assert found_legend is show_legend + + +@pytest.mark.parametrize("orient", [("LR"), ("TB"), ("RL"), ("BT")]) +def test_function_graph_display_orient(orient: str, tmp_path: pathlib.Path): + dot_file_path = tmp_path / "dag.dot" + fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2}) + + fg.display( + set(fg.get_nodes()), + output_file_path=str(dot_file_path), + render_kwargs={"view": False}, + orient=orient, + ) + dot = dot_file_path.open("r").read() + + # this could break if a rankdir is given to the legend subgraph + assert f"rankdir={orient}" in dot + + +@pytest.mark.parametrize("hide_inputs", [(True), (False)]) +def test_function_graph_display_inputs(hide_inputs: bool, tmp_path: pathlib.Path): + dot_file_path = tmp_path / "dag.dot" + fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2}) + + fg.display( + set(fg.get_nodes()), + output_file_path=str(dot_file_path), + render_kwargs={"view": False}, + hide_inputs=hide_inputs, + ) + dot_lines = dot_file_path.open("r").readlines() + + found_input = any(line.startswith("\t_") for line in dot_lines) + assert found_input is not hide_inputs + + def test_function_graph_display_without_saving(): """Tests that display works when None is passed in for path""" fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2}) From 92dc5e7a8e3424ca3267381d5aef6bbd0f2c29bc Mon Sep 17 00:00:00 2001 From: zilto Date: Thu, 2 Nov 2023 12:39:12 -0400 Subject: [PATCH 05/10] added parameters to --- hamilton/graph.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/hamilton/graph.py b/hamilton/graph.py index c8dd6122d..b1722a0ab 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -183,7 +183,7 @@ def _get_node_label( return f"<{name}

{type_}>" def _get_input_label(input_nodes: frozenset[node.Node]) -> str: - """Get a graphviz HTML-like node label formatted as a table. + """Get a graphviz HTML-like node label formatted aspyer a table. Each row is a different input node with one column containing the name and the other the type. ref: https://graphviz.org/doc/info/shapes.html#html @@ -600,6 +600,10 @@ def display( graphviz_kwargs: dict = None, node_modifiers: Dict[str, Set[VisualizationNodeModifiers]] = None, strictly_display_only_passed_in_nodes: bool = False, + show_legend: bool = True, + orient: str = "LR", + hide_inputs: bool = False, + deduplicate_inputs: bool = False, ) -> Optional["graphviz.Digraph"]: # noqa F821 """Function to display the graph represented by the passed in nodes. @@ -612,6 +616,13 @@ def display( e.g. {'node_name': {NodeModifiers.IS_USER_INPUT}} will set the node named 'node_name' to be a user input. :param strictly_display_only_passed_in_nodes: if True, only display the nodes passed in. Else defaults to displaying also what nodes a node depends on (i.e. all nodes that feed into it). + :param show_legend: If True, add a legend to the visualization based on the DAG's nodes. + :param orient: `LR` stands for "left to right". Accepted values are TB, LR, BT, RL. + `orient` will be overwridden by the value of `graphviz_kwargs['graph_attr']['rankdir']` + see (https://graphviz.org/docs/attr-types/rankdir/) + :param hide_inputs: If True, no input nodes are displayed. + :param deduplicate_inputs: If True, remove duplicate input nodes. + Can improve readability depending on the specifics of the DAG. :return: the graphviz graph object if it was created. None if not. """ # Check to see if optional dependencies have been installed. @@ -633,6 +644,10 @@ def display( graphviz_kwargs, node_modifiers, strictly_display_only_passed_in_nodes, + show_legend, + orient, + hide_inputs, + deduplicate_inputs, ) kwargs = {"view": True} if render_kwargs and isinstance(render_kwargs, dict): From 34539f173bb6cff6f956c10894f21ab4b3deb4ba Mon Sep 17 00:00:00 2001 From: zilto Date: Thu, 2 Nov 2023 13:36:28 -0400 Subject: [PATCH 06/10] fixed failing tests for graph.py --- tests/test_graph.py | 100 ++++++++++++++++++++++++++++---------------- 1 file changed, 64 insertions(+), 36 deletions(-) diff --git a/tests/test_graph.py b/tests/test_graph.py index 58ee6fa7c..4428b6414 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,6 +1,5 @@ import inspect import pathlib -import tempfile import uuid from itertools import permutations @@ -21,6 +20,7 @@ import tests.resources.optional_dependencies import tests.resources.parametrized_inputs import tests.resources.parametrized_nodes +import tests.resources.test_default_args import tests.resources.typing_vs_not_typing from hamilton import ad_hoc_utils, base, graph, node from hamilton.execution import graph_functions @@ -502,8 +502,9 @@ def test_function_graph_has_cycles_false(): assert fg.has_cycles(nodes, user_nodes) is False -def test_function_graph_display(): +def test_function_graph_display(tmp_path: pathlib.Path): """Tests that display saves a file""" + dot_file_path = tmp_path / "dag.dot" fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2}) node_modifiers = {"B": {graph.VisualizationNodeModifiers.IS_OUTPUT}} all_nodes = set() @@ -513,30 +514,49 @@ def test_function_graph_display(): all_nodes.add(n) # hack of a test -- but it works... sort the lines and match them up. # why? because for some reason given the same graph, the output file isn't deterministic. - expected = sorted( + # for the same reason, order of input nodes are non-deterministic + expected_set = set( [ + '\t\tfunction [fillcolor="#b4d8e4" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]\n', + "\t\tgraph [fontname=helvetica label=Legend rank=same]\n", + "\t\tinput [fontname=Helvetica margin=0.15 shape=rectangle style=dashed]\n", + '\t\toutput [fillcolor="#FFC857" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]\n', + "\tA -> B\n", + "\tA -> C\n", + '\tA [label=<A

int> fillcolor="#b4d8e4" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]\n', + '\tB [label=<B

int> fillcolor="#FFC857" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]\n', + '\tC [label=<C

int> fillcolor="#b4d8e4" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]\n', + "\t_A_inputs -> A\n", + # commenting out input node: '\t_A_inputs [label=<
cint
bint
> fontname=Helvetica margin=0.15 shape=rectangle style=dashed]\n', + "\tgraph [compound=true concentrate=true rankdir=LR ranksep=0.4]\n", + "\tsubgraph cluster__legend {\n", + "\t}\n", "// Dependency Graph\n", "digraph {\n", - "\tA [label=A]\n", - "\tC [label=C]\n", - "\tB [label=B shape=rectangle]\n", - '\tc [label="Input: c" style=dashed]\n', - '\tb [label="Input: b" style=dashed]\n', - "\tb -> A\n", - "\tc -> A\n", - "\tA -> C\n", - "\tA -> B\n", "}\n", ] ) - with tempfile.TemporaryDirectory() as tmp_dir: - path = tmp_dir.join("test.dot") - fg.display(all_nodes, str(path), {"view": False}, None, node_modifiers) - with open(str(path), "r") as dot_file: - actual = sorted(dot_file.readlines()) - assert actual == expected - dot_file = fg.display(all_nodes, output_file_path=None, node_modifiers=node_modifiers) - assert dot_file is not None + + fg.display( + all_nodes, + output_file_path=str(dot_file_path), + render_kwargs={"view": False}, + node_modifiers=node_modifiers, + ) + + dot = dot_file_path.open("r").readlines() + dot_set = set(dot) + + assert dot_set.issuperset(expected_set) and len(dot_set.difference(expected_set)) == 1 + + +def test_function_graph_display_no_dot_output(tmp_path: pathlib.Path): + dot_file_path = tmp_path / "dag.dot" + fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2}) + + fg.display(set(fg.get_nodes()), output_file_path=None) + + assert not dot_file_path.exists() @pytest.mark.parametrize("show_legend", [(True), (False)]) @@ -618,33 +638,41 @@ def test_create_graphviz_graph(): } # hack of a test -- but it works... sort the lines and match them up. # why? because for some reason given the same graph, the output file isn't deterministic. - expected = sorted( + # for the same reason, order of input nodes are non-deterministic + expected_set = set( [ - "// test-graph", + "// Dependency Graph", + "", "digraph {", "\tgraph [ratio=1]", - "\tB [label=B shape=rectangle]", - "\tA [label=A]", - "\tC [label=C]", - '\tb [label="Input: b" style=dashed]', - '\tc [label="Input: c" style=dashed]', + '\tB [label=<B

int> fillcolor="#FFC857" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]', + '\tC [label=<C

int> fillcolor="#b4d8e4" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]', + '\tA [label=<A

int> fillcolor="#b4d8e4" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]', "\tA -> B", - "\tb -> A", - "\tc -> A", "\tA -> C", + # commenting out input node: '\t_A_inputs [label=<
cint
bint
> fontname=Helvetica margin=0.15 shape=rectangle style=dashed]', + "\t_A_inputs -> A", + "\tsubgraph cluster__legend {", + "\t\tgraph [fontname=helvetica label=Legend rank=same]", + "\t\tinput [fontname=Helvetica margin=0.15 shape=rectangle style=dashed]", + '\t\tfunction [fillcolor="#b4d8e4" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]', + '\t\toutput [fillcolor="#FFC857" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]', + "\t}", "}", "", ] ) - if "" in expected: - expected.remove("") + digraph = graph.create_graphviz_graph( - nodez, "test-graph", dict(graph_attr={"ratio": "1"}), node_modifiers, False + nodez, + "Dependency Graph\n", + graphviz_kwargs=dict(graph_attr={"ratio": "1"}), + node_modifiers=node_modifiers, + strictly_display_only_nodes_passed_in=False, ) - actual = sorted(str(digraph).split("\n")) - if "" in actual: - actual.remove("") - assert actual == expected + dot_set = set(str(digraph).split("\n")) + + assert dot_set.issuperset(expected_set) and len(dot_set.difference(expected_set)) == 1 def test_create_networkx_graph(): From 289676fb43d877104cc13056a9e95709668ba647 Mon Sep 17 00:00:00 2001 From: zilto Date: Thu, 2 Nov 2023 13:49:07 -0400 Subject: [PATCH 07/10] removed type hint to pass 3.7 and 3.8 tests --- tests/test_graph.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/test_graph.py b/tests/test_graph.py index 4428b6414..2c1b12d0a 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,5 +1,4 @@ import inspect -import pathlib import uuid from itertools import permutations @@ -502,7 +501,7 @@ def test_function_graph_has_cycles_false(): assert fg.has_cycles(nodes, user_nodes) is False -def test_function_graph_display(tmp_path: pathlib.Path): +def test_function_graph_display(tmp_path): """Tests that display saves a file""" dot_file_path = tmp_path / "dag.dot" fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2}) @@ -550,7 +549,7 @@ def test_function_graph_display(tmp_path: pathlib.Path): assert dot_set.issuperset(expected_set) and len(dot_set.difference(expected_set)) == 1 -def test_function_graph_display_no_dot_output(tmp_path: pathlib.Path): +def test_function_graph_display_no_dot_output(tmp_path): dot_file_path = tmp_path / "dag.dot" fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2}) @@ -560,7 +559,7 @@ def test_function_graph_display_no_dot_output(tmp_path: pathlib.Path): @pytest.mark.parametrize("show_legend", [(True), (False)]) -def test_function_graph_display_legend(show_legend: bool, tmp_path: pathlib.Path): +def test_function_graph_display_legend(show_legend, tmp_path): dot_file_path = tmp_path / "dag.dot" fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2}) @@ -577,7 +576,7 @@ def test_function_graph_display_legend(show_legend: bool, tmp_path: pathlib.Path @pytest.mark.parametrize("orient", [("LR"), ("TB"), ("RL"), ("BT")]) -def test_function_graph_display_orient(orient: str, tmp_path: pathlib.Path): +def test_function_graph_display_orient(orient, tmp_path): dot_file_path = tmp_path / "dag.dot" fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2}) @@ -594,7 +593,7 @@ def test_function_graph_display_orient(orient: str, tmp_path: pathlib.Path): @pytest.mark.parametrize("hide_inputs", [(True), (False)]) -def test_function_graph_display_inputs(hide_inputs: bool, tmp_path: pathlib.Path): +def test_function_graph_display_inputs(hide_inputs, tmp_path): dot_file_path = tmp_path / "dag.dot" fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2}) From 4f05becbfc3c90b72a966bf0ec99dd0587e63d73 Mon Sep 17 00:00:00 2001 From: zilto Date: Thu, 2 Nov 2023 16:13:44 -0400 Subject: [PATCH 08/10] fixed Python 3.7 and 3.8 type hints in graph.py --- hamilton/graph.py | 6 +++--- tests/test_graph.py | 11 ++++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/hamilton/graph.py b/hamilton/graph.py index b1722a0ab..200540226 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -8,7 +8,7 @@ import logging from enum import Enum from types import ModuleType -from typing import Any, Callable, Collection, Dict, List, Optional, Set, Tuple, Type +from typing import Any, Callable, Collection, Dict, FrozenSet, List, Optional, Set, Tuple, Type from hamilton import base, node from hamilton.execution import graph_functions @@ -182,7 +182,7 @@ def _get_node_label( type_ = n.type.__name__ if type_ is None else type_ return f"<{name}

{type_}>" - def _get_input_label(input_nodes: frozenset[node.Node]) -> str: + def _get_input_label(input_nodes: FrozenSet[node.Node]) -> str: """Get a graphviz HTML-like node label formatted aspyer a table. Each row is a different input node with one column containing the name and the other the type. @@ -289,7 +289,7 @@ def _get_edge_style(from_type: str, to_type: str) -> dict: return edge_style - def _get_legend(node_types): + def _get_legend(node_types: Set[str]): """Create a visualization legend as a graphviz subgraph. The legend includes the node types and modifiers presente in the visualization. """ diff --git a/tests/test_graph.py b/tests/test_graph.py index 2c1b12d0a..4428b6414 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,4 +1,5 @@ import inspect +import pathlib import uuid from itertools import permutations @@ -501,7 +502,7 @@ def test_function_graph_has_cycles_false(): assert fg.has_cycles(nodes, user_nodes) is False -def test_function_graph_display(tmp_path): +def test_function_graph_display(tmp_path: pathlib.Path): """Tests that display saves a file""" dot_file_path = tmp_path / "dag.dot" fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2}) @@ -549,7 +550,7 @@ def test_function_graph_display(tmp_path): assert dot_set.issuperset(expected_set) and len(dot_set.difference(expected_set)) == 1 -def test_function_graph_display_no_dot_output(tmp_path): +def test_function_graph_display_no_dot_output(tmp_path: pathlib.Path): dot_file_path = tmp_path / "dag.dot" fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2}) @@ -559,7 +560,7 @@ def test_function_graph_display_no_dot_output(tmp_path): @pytest.mark.parametrize("show_legend", [(True), (False)]) -def test_function_graph_display_legend(show_legend, tmp_path): +def test_function_graph_display_legend(show_legend: bool, tmp_path: pathlib.Path): dot_file_path = tmp_path / "dag.dot" fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2}) @@ -576,7 +577,7 @@ def test_function_graph_display_legend(show_legend, tmp_path): @pytest.mark.parametrize("orient", [("LR"), ("TB"), ("RL"), ("BT")]) -def test_function_graph_display_orient(orient, tmp_path): +def test_function_graph_display_orient(orient: str, tmp_path: pathlib.Path): dot_file_path = tmp_path / "dag.dot" fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2}) @@ -593,7 +594,7 @@ def test_function_graph_display_orient(orient, tmp_path): @pytest.mark.parametrize("hide_inputs", [(True), (False)]) -def test_function_graph_display_inputs(hide_inputs, tmp_path): +def test_function_graph_display_inputs(hide_inputs: bool, tmp_path: pathlib.Path): dot_file_path = tmp_path / "dag.dot" fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2}) From 353651815c5971c682857400b44a8442a062b5e9 Mon Sep 17 00:00:00 2001 From: zilto Date: Thu, 2 Nov 2023 17:15:40 -0400 Subject: [PATCH 09/10] yet another type annotation --- hamilton/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hamilton/graph.py b/hamilton/graph.py index 200540226..1ef79bc1a 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -213,7 +213,7 @@ def _get_node_type(n: node.Node) -> str: else: return "function" - def _get_node_style(node_type: str) -> dict[str, str]: + def _get_node_style(node_type: str) -> Dict[str, str]: """Get the style of a node type. Graphviz needs values to be strings. """ From 775dbf210dc379f282fb10cf4f58a8d72c7c1f4f Mon Sep 17 00:00:00 2001 From: zilto Date: Thu, 2 Nov 2023 17:43:46 -0400 Subject: [PATCH 10/10] yet yet another type annotation for 3.7 --- hamilton/graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hamilton/graph.py b/hamilton/graph.py index 1ef79bc1a..557a65a1d 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -249,7 +249,7 @@ def _get_node_style(node_type: str) -> Dict[str, str]: return node_style - def _get_function_modifier_style(modifier: str) -> dict[str, str]: + def _get_function_modifier_style(modifier: str) -> Dict[str, str]: """Get the style of a modifier. The dictionary returned is used to overwrite values of the base node style. Graphviz needs values to be strings. @@ -269,7 +269,7 @@ def _get_function_modifier_style(modifier: str) -> dict[str, str]: return modifier_style - def _get_edge_style(from_type: str, to_type: str) -> dict: + def _get_edge_style(from_type: str, to_type: str) -> Dict: """ Graphviz needs values to be strings.