From 54b2a4f8fcec2d6485af43b6c1c4f27f1415a8fb Mon Sep 17 00:00:00 2001 From: Thierry Jean <68975210+zilto@users.noreply.github.com> Date: Sun, 5 May 2024 18:38:04 -0400 Subject: [PATCH] Remove temporary DOT files (#884) * Prevent temporary DOT file Hamilton visualizations rely on the graphviz library. It defines graphs using the DOT language, which defines one statement per line using a string. Previously, Hamilton used `graphviz.Digraph.render()` to produce visualizations. This has the side effect of producing an intermediary DOT file on disk. This is most often of no use and clutters the directory. Now, we are switching to `graphviz.Digraph.pipe()` to write bytes directly to an open file. Tests were updated accordingly. The keyword argument `keep_dot` was added to viz functions in case users still want this DOT file to be produced. It allows re-rendering the viz with a different style without re-executing the Hamilton code. It could be useful when iterating over custom styling. * added keep_dot to viz functions * added back view kwarg; fixed typing for 3.8 --------- Co-authored-by: zilto --- hamilton/driver.py | 18 +++++++++ hamilton/graph.py | 16 ++++++-- tests/test_graph.py | 93 ++++++++++++++++----------------------------- 3 files changed, 63 insertions(+), 64 deletions(-) diff --git a/hamilton/driver.py b/hamilton/driver.py index a24312f83..2a53eb268 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -746,6 +746,7 @@ def display_all_functions( deduplicate_inputs: bool = False, show_schema: bool = True, custom_style_function: Callable = None, + keep_dot: bool = False, ) -> Optional["graphviz.Digraph"]: # noqa F821 """Displays the graph of all functions loaded! @@ -767,6 +768,7 @@ def display_all_functions( :param show_schema: If True, display the schema of the DAG if the nodes have schema data provided :param custom_style_function: Optional. Custom style function. See example in repository for example use. 
+ :param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html) :return: the graphviz object if you want to do more with it. If returned as the result in a Jupyter Notebook cell, it will render. """ @@ -781,6 +783,7 @@ def display_all_functions( deduplicate_inputs=deduplicate_inputs, display_fields=show_schema, custom_style_function=custom_style_function, + keep_dot=keep_dot, ) except ImportError as e: logger.warning(f"Unable to import {e}", exc_info=True) @@ -802,6 +805,7 @@ def _visualize_execution_helper( show_schema: bool = True, custom_style_function: Callable = None, bypass_validation: bool = False, + keep_dot: bool = False, ): """Helper function to visualize execution, using a passed-in function graph. @@ -816,6 +820,7 @@ def _visualize_execution_helper( :param deduplicate_inputs: If True, remove duplicate input nodes. :param show_schema: If True, display the schema of the DAG if nodes have schema data provided :param custom_style_function: Optional. Custom style function. + :param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html) :return: the graphviz object if you want to do more with it. """ # TODO should determine if the visualization logic should live here or in the graph.py module @@ -851,6 +856,7 @@ def _visualize_execution_helper( display_fields=show_schema, custom_style_function=custom_style_function, config=fn_graph._config, + keep_dot=keep_dot, ) except ImportError as e: logger.warning(f"Unable to import {e}", exc_info=True) @@ -871,6 +877,7 @@ def visualize_execution( show_schema: bool = True, custom_style_function: Callable = None, bypass_validation: bool = False, + keep_dot: bool = False, ) -> Optional["graphviz.Digraph"]: # noqa F821 """Visualizes Execution. @@ -902,6 +909,7 @@ def visualize_execution( Can improve readability depending on the specifics of the DAG. 
:param show_schema: If True, display the schema of the DAG if nodes have schema data provided :param custom_style_function: Optional. Custom style function. + :param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html) :return: the graphviz object if you want to do more with it. If returned as the result in a Jupyter Notebook cell, it will render. """ @@ -922,6 +930,7 @@ def visualize_execution( show_schema=show_schema, custom_style_function=custom_style_function, bypass_validation=bypass_validation, + keep_dot=keep_dot, ) @capture_function_usage @@ -988,6 +997,7 @@ def display_downstream_of( deduplicate_inputs: bool = False, show_schema: bool = True, custom_style_function: Callable = None, + keep_dot: bool = False, ) -> Optional["graphviz.Digraph"]: # noqa F821 """Creates a visualization of the DAG starting from the passed in function name(s). @@ -1010,6 +1020,7 @@ def display_downstream_of( Can improve readability depending on the specifics of the DAG. :param show_schema: If True, display the schema of the DAG if nodes have schema data provided :param custom_style_function: Optional. Custom style function. + :param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html) :return: the graphviz object if you want to do more with it. If returned as the result in a Jupyter Notebook cell, it will render. """ @@ -1054,6 +1065,7 @@ def display_upstream_of( deduplicate_inputs: bool = False, show_schema: bool = True, custom_style_function: Callable = None, + keep_dot: bool = False, ) -> Optional["graphviz.Digraph"]: # noqa F821 """Creates a visualization of the DAG going backwards from the passed in function name(s). @@ -1076,6 +1088,7 @@ def display_upstream_of( Can improve readability depending on the specifics of the DAG. :param show_schema: If True, display the schema of the DAG if nodes have schema data provided :param custom_style_function: Optional. Custom style function. 
+ :param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html) :return: the graphviz object if you want to do more with it. If returned as the result in a Jupyter Notebook cell, it will render. """ @@ -1172,6 +1185,7 @@ def visualize_path_between( deduplicate_inputs: bool = False, show_schema: bool = True, custom_style_function: Callable = None, + keep_dot: bool = False, ) -> Optional["graphviz.Digraph"]: # noqa F821 """Visualizes the path between two nodes. @@ -1197,6 +1211,7 @@ def visualize_path_between( :param show_schema: If True, display the schema of the DAG if nodes have schema data provided :return: graphviz object. :param custom_style_function: Optional. Custom style function. + :param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html) :raise ValueError: if the upstream or downstream node names are not found in the graph, or there is no path between them. """ @@ -1256,6 +1271,7 @@ def visualize_path_between( display_fields=show_schema, custom_style_function=custom_style_function, config=self.graph._config, + keep_dot=keep_dot, ) except ImportError as e: logger.warning(f"Unable to import {e}", exc_info=True) @@ -1523,6 +1539,7 @@ def visualize_materialization( show_schema: bool = True, custom_style_function: Callable = None, bypass_validation: bool = False, + keep_dot: bool = False, ) -> Optional["graphviz.Digraph"]: # noqa F821 """Visualizes materialization. This helps give you a sense of how materialization will impact the DAG. 
@@ -1572,6 +1589,7 @@ def visualize_materialization( show_schema=show_schema, custom_style_function=custom_style_function, bypass_validation=bypass_validation, + keep_dot=keep_dot, ) def validate_execution( diff --git a/hamilton/graph.py b/hamilton/graph.py index 328c82b05..cc66b5b1a 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -9,6 +9,7 @@ import inspect import logging import os.path +import pathlib import uuid from enum import Enum from types import ModuleType @@ -747,6 +748,7 @@ def display_all( deduplicate_inputs: bool = False, display_fields: bool = True, custom_style_function: Callable = None, + keep_dot: bool = False, ) -> Optional["graphviz.Digraph"]: # noqa F821 """Displays & saves a dot file of the entire DAG structure constructed. @@ -764,6 +766,7 @@ def display_all( Can improve readability depending on the specifics of the DAG. :param display_fields: If True, display fields in the graph if node has attached schema metadata :param custom_style_function: Optional. Custom style function. + :param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html) :return: the graphviz graph object if it was created. None if not. 
""" all_nodes = set() @@ -789,6 +792,7 @@ def display_all( display_fields=display_fields, custom_style_function=custom_style_function, config=self._config, + keep_dot=keep_dot, ) def has_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> bool: @@ -823,7 +827,7 @@ def get_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> List[ @staticmethod def display( nodes: Set[node.Node], - output_file_path: Optional[str] = "test-output/graph.gv", + output_file_path: Optional[str] = None, render_kwargs: dict = None, graphviz_kwargs: dict = None, node_modifiers: Dict[str, Set[VisualizationNodeModifiers]] = None, @@ -835,6 +839,7 @@ def display( display_fields: bool = True, custom_style_function: Callable = None, config: dict = None, + keep_dot: bool = False, ) -> Optional["graphviz.Digraph"]: # noqa F821 """Function to display the graph represented by the passed in nodes. @@ -894,7 +899,7 @@ def display( custom_style_function=custom_style_function, config=config, ) - kwargs = {"view": False, "format": "png"} # default format = png + kwargs = {"format": "png"} # default format = png if output_file_path: # infer format from path output_file_path, suffix = os.path.splitext(output_file_path) if suffix != "": @@ -903,7 +908,12 @@ def display( if render_kwargs and isinstance(render_kwargs, dict): # accept explicit format kwargs.update(render_kwargs) if output_file_path: - dot.render(output_file_path, **kwargs) + if keep_dot: + kwargs["view"] = kwargs.get("view", False) + dot.render(output_file_path, **kwargs) + else: + kwargs.pop("view", None) + pathlib.Path(output_file_path).write_bytes(dot.pipe(**kwargs)) return dot def get_impacted_nodes(self, var_changes: List[str]) -> Set[node.Node]: diff --git a/tests/test_graph.py b/tests/test_graph.py index f81d879f6..bc07fd0a4 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,8 +1,8 @@ import inspect -import os import pathlib import uuid from itertools import permutations +from typing import List 
import pandas as pd import pytest @@ -727,12 +727,12 @@ def test_function_graph_display(tmp_path: pathlib.Path): fg.display( all_nodes, output_file_path=str(dot_file_path), - render_kwargs={"view": False}, node_modifiers=node_modifiers, config=config, + keep_dot=True, ) - dot = dot_file_path.open("r").readlines() - dot_set = set(dot) + dot_file = dot_file_path.open("r").readlines() + dot_set = set(dot_file) assert dot_set.issuperset(expected_set) and len(dot_set.difference(expected_set)) == 1 @@ -756,7 +756,6 @@ def _styling_function(*, node, node_class): digraph = fg.display( set(fg.get_nodes()), - output_file_path=None, custom_style_function=_styling_function, config=config, ) @@ -779,7 +778,6 @@ def _styling_function(*, node, node_class): digraph = fg.display( set(fg.get_nodes()), - output_file_path=None, custom_style_function=_styling_function, config=config, ) @@ -808,7 +806,6 @@ def _styling_function(*, node, node_class): digraph = fg.display( set(fg.get_nodes()), - output_file_path=None, custom_style_function=_styling_function, config=config, ) @@ -826,60 +823,47 @@ def _styling_function(*, node, node_class): @pytest.mark.parametrize("show_legend", [(True), (False)]) -def test_function_graph_display_legend(show_legend: bool, tmp_path: pathlib.Path): - dot_file_path = tmp_path / "dag.png" +def test_function_graph_display_legend(show_legend: bool): config = {"b": 1, "c": 2} fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config) - fg.display( + dot = fg.display( set(fg.get_nodes()), - output_file_path=str(dot_file_path), - render_kwargs={"view": False}, show_legend=show_legend, config=config, ) - dot_file = pathlib.Path(os.path.splitext(str(dot_file_path))[0]) - dot = dot_file.open("r").read() - found_legend = "cluster__legend" in dot + found_legend = "cluster__legend" in dot.source assert found_legend is show_legend @pytest.mark.parametrize("orient", [("LR"), ("TB"), ("RL"), ("BT")]) -def 
test_function_graph_display_orient(orient: str, tmp_path: pathlib.Path): - dot_file_path = tmp_path / "dag" +def test_function_graph_display_orient(orient: str): config = {"b": 1, "c": 2} fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config) - fg.display( + dot = fg.display( set(fg.get_nodes()), - output_file_path=str(dot_file_path), - render_kwargs={"view": False}, orient=orient, config=config, ) - dot = dot_file_path.open("r").read() # this could break if a rankdir is given to the legend subgraph - assert f"rankdir={orient}" in dot + assert f"rankdir={orient}" in dot.source @pytest.mark.parametrize("hide_inputs", [(True,), (False,)]) -def test_function_graph_display_inputs(hide_inputs: bool, tmp_path: pathlib.Path): - dot_file_path = tmp_path / "dag" +def test_function_graph_display_inputs(hide_inputs: bool): config = {"b": 1, "c": 2} fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config) - fg.display( + dot = fg.display( set(fg.get_nodes()), - output_file_path=str(dot_file_path), - render_kwargs={"view": False}, hide_inputs=hide_inputs, config=config, ) - dot_lines = dot_file_path.open("r").readlines() - found_input = any(line.startswith("\t_") for line in dot_lines) + found_input = any(line.startswith("\t_") for line in dot.body) assert found_input is not hide_inputs @@ -903,9 +887,7 @@ def test_function_graph_display_without_saving(): @pytest.mark.parametrize("display_fields", [(True,), (False,)]) -def test_function_graph_display_fields(display_fields: bool, tmp_path: pathlib.Path): - dot_file_path = tmp_path / "dag" - +def test_function_graph_display_fields(display_fields: bool): @schema.output(("foo", "int"), ("bar", "float"), ("baz", "str")) def df_with_schema() -> pd.DataFrame: pass @@ -914,30 +896,25 @@ def df_with_schema() -> pd.DataFrame: config = {} fg = graph.FunctionGraph.from_modules(mod, config=config) - fg.display( + dot = fg.display( set(fg.get_nodes()), - 
output_file_path=str(dot_file_path), - render_kwargs={"view": False}, display_fields=display_fields, config=config, ) - dot_lines = dot_file_path.open("r").readlines() if display_fields: - assert any("foo" in line for line in dot_lines) - assert any("bar" in line for line in dot_lines) - assert any("baz" in line for line in dot_lines) - assert any("cluster" in line for line in dot_lines) + assert any("foo" in line for line in dot.body) + assert any("bar" in line for line in dot.body) + assert any("baz" in line for line in dot.body) + assert any("cluster" in line for line in dot.body) else: - assert not any("foo" in line for line in dot_lines) - assert not any("bar" in line for line in dot_lines) - assert not any("baz" in line for line in dot_lines) - assert not any("cluster" in line for line in dot_lines) + assert not any("foo" in line for line in dot.body) + assert not any("bar" in line for line in dot.body) + assert not any("baz" in line for line in dot.body) + assert not any("cluster" in line for line in dot.body) -def test_function_graph_display_fields_shared_schema(tmp_path: pathlib.Path): +def test_function_graph_display_fields_shared_schema(): # This ensures an edge case where they end up getting dropped if there are duplicates - dot_file_path = tmp_path / "dag" - SCHEMA = (("foo", "int"), ("bar", "float"), ("baz", "str")) @schema.output(*SCHEMA) @@ -952,22 +929,19 @@ def df_2_with_schema() -> pd.DataFrame: config = {} fg = graph.FunctionGraph.from_modules(mod, config=config) - fg.display( + dot = fg.display( set(fg.get_nodes()), - output_file_path=str(dot_file_path), - render_kwargs={"view": False}, display_fields=True, config=config, ) - dot_lines = dot_file_path.open("r").readlines() - def _get_occurances(var: str): - return [item for item in dot_lines if var in item] + def _get_occurances(var: str, lines: List[str]): + return [item for item in lines if var in item] # We just need to make sure these show up twice - assert len(_get_occurances("foo=")) == 2 
- assert len(_get_occurances("bar=")) == 2 - assert len(_get_occurances("baz=")) == 2 + assert len(_get_occurances("foo=", dot.body)) == 2 + assert len(_get_occurances("bar=", dot.body)) == 2 + assert len(_get_occurances("baz=", dot.body)) == 2 def test_function_graph_display_config_node(): @@ -977,13 +951,12 @@ def test_function_graph_display_config_node(): dot = fg.display(set(fg.get_nodes()), config=config) - # dot.body is a list of string # lines start tab then node name; check if "b" is a node in the graphviz object assert any(line.startswith("\tX") for line in dot.body) # TODO use high-level visualization dot as fixtures for reuse across tests -def test_display_config_node(tmp_path: pathlib.Path): +def test_display_config_node(): """Check if config is displayed by high-level hamilton.driver.display...""" from hamilton import driver from hamilton.io.materialization import to @@ -997,9 +970,7 @@ def test_display_config_node(tmp_path: pathlib.Path): between_dot = dr.visualize_path_between("A", "C") exec_dot = dr.visualize_execution(["C"], inputs={"b": 1, "c": 2}) materialize_dot = dr.visualize_materialization( - to.json( - id="saver", dependencies=["C"], combine=base.DictResult(), path=f"{tmp_path}/saver.json" - ), + to.json(id="saver", dependencies=["C"], combine=base.DictResult(), path="saver.json"), inputs={"b": 1, "c": 2}, )