Skip to content

Commit

Permalink
Remove temporary DOT files (#884)
Browse files Browse the repository at this point in the history
* Prevent temporary DOT file

Hamilton visualizations rely on the graphviz library.
Graphviz defines graphs using the DOT language, in which
each statement is written as a string, one per line.

Previously, Hamilton used `graphviz.Digraph.render()`
to produce visualizations. This has the side-effect
of producing an intermediary DOT file on disk. This
is most often of no use and clutters the directory.

Now, we are switching to `graphviz.Digraph.pipe()`
to write bytes directly to an open file. Tests were
updated accordingly.

The keyword argument `keep_dot` was added to viz
functions in case users still want this DOT file
to be produced. It allows rerendering the viz with
a different style without re-executing the Hamilton
code. It could be useful when iterating over custom
styling.

* added keep_dot to viz functions

* added back view kwarg; fixed typing for 3.8

---------

Co-authored-by: zilto <tjean@DESKTOP-V6JDCS2>
  • Loading branch information
zilto and zilto authored May 5, 2024
1 parent 821a317 commit 54b2a4f
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 64 deletions.
18 changes: 18 additions & 0 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,7 @@ def display_all_functions(
deduplicate_inputs: bool = False,
show_schema: bool = True,
custom_style_function: Callable = None,
keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Displays the graph of all functions loaded!
Expand All @@ -767,6 +768,7 @@ def display_all_functions(
:param show_schema: If True, display the schema of the DAG if
the nodes have schema data provided
:param custom_style_function: Optional. Custom style function. See example in repository for example use.
:param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
:return: the graphviz object if you want to do more with it.
If returned as the result in a Jupyter Notebook cell, it will render.
"""
Expand All @@ -781,6 +783,7 @@ def display_all_functions(
deduplicate_inputs=deduplicate_inputs,
display_fields=show_schema,
custom_style_function=custom_style_function,
keep_dot=keep_dot,
)
except ImportError as e:
logger.warning(f"Unable to import {e}", exc_info=True)
Expand All @@ -802,6 +805,7 @@ def _visualize_execution_helper(
show_schema: bool = True,
custom_style_function: Callable = None,
bypass_validation: bool = False,
keep_dot: bool = False,
):
"""Helper function to visualize execution, using a passed-in function graph.
Expand All @@ -816,6 +820,7 @@ def _visualize_execution_helper(
:param deduplicate_inputs: If True, remove duplicate input nodes.
:param show_schema: If True, display the schema of the DAG if nodes have schema data provided
:param custom_style_function: Optional. Custom style function.
:param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
:return: the graphviz object if you want to do more with it.
"""
# TODO should determine if the visualization logic should live here or in the graph.py module
Expand Down Expand Up @@ -851,6 +856,7 @@ def _visualize_execution_helper(
display_fields=show_schema,
custom_style_function=custom_style_function,
config=fn_graph._config,
keep_dot=keep_dot,
)
except ImportError as e:
logger.warning(f"Unable to import {e}", exc_info=True)
Expand All @@ -871,6 +877,7 @@ def visualize_execution(
show_schema: bool = True,
custom_style_function: Callable = None,
bypass_validation: bool = False,
keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Visualizes Execution.
Expand Down Expand Up @@ -902,6 +909,7 @@ def visualize_execution(
Can improve readability depending on the specifics of the DAG.
:param show_schema: If True, display the schema of the DAG if nodes have schema data provided
:param custom_style_function: Optional. Custom style function.
:param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
:return: the graphviz object if you want to do more with it.
If returned as the result in a Jupyter Notebook cell, it will render.
"""
Expand All @@ -922,6 +930,7 @@ def visualize_execution(
show_schema=show_schema,
custom_style_function=custom_style_function,
bypass_validation=bypass_validation,
keep_dot=keep_dot,
)

@capture_function_usage
Expand Down Expand Up @@ -988,6 +997,7 @@ def display_downstream_of(
deduplicate_inputs: bool = False,
show_schema: bool = True,
custom_style_function: Callable = None,
keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Creates a visualization of the DAG starting from the passed in function name(s).
Expand All @@ -1010,6 +1020,7 @@ def display_downstream_of(
Can improve readability depending on the specifics of the DAG.
:param show_schema: If True, display the schema of the DAG if nodes have schema data provided
:param custom_style_function: Optional. Custom style function.
:param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
:return: the graphviz object if you want to do more with it.
If returned as the result in a Jupyter Notebook cell, it will render.
"""
Expand Down Expand Up @@ -1054,6 +1065,7 @@ def display_upstream_of(
deduplicate_inputs: bool = False,
show_schema: bool = True,
custom_style_function: Callable = None,
keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Creates a visualization of the DAG going backwards from the passed in function name(s).
Expand All @@ -1076,6 +1088,7 @@ def display_upstream_of(
Can improve readability depending on the specifics of the DAG.
:param show_schema: If True, display the schema of the DAG if nodes have schema data provided
:param custom_style_function: Optional. Custom style function.
:param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
:return: the graphviz object if you want to do more with it.
If returned as the result in a Jupyter Notebook cell, it will render.
"""
Expand Down Expand Up @@ -1172,6 +1185,7 @@ def visualize_path_between(
deduplicate_inputs: bool = False,
show_schema: bool = True,
custom_style_function: Callable = None,
keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Visualizes the path between two nodes.
Expand All @@ -1197,6 +1211,7 @@ def visualize_path_between(
:param show_schema: If True, display the schema of the DAG if nodes have schema data provided
:return: graphviz object.
:param custom_style_function: Optional. Custom style function.
:param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
:raise ValueError: if the upstream or downstream node names are not found in the graph,
or there is no path between them.
"""
Expand Down Expand Up @@ -1256,6 +1271,7 @@ def visualize_path_between(
display_fields=show_schema,
custom_style_function=custom_style_function,
config=self.graph._config,
keep_dot=keep_dot,
)
except ImportError as e:
logger.warning(f"Unable to import {e}", exc_info=True)
Expand Down Expand Up @@ -1523,6 +1539,7 @@ def visualize_materialization(
show_schema: bool = True,
custom_style_function: Callable = None,
bypass_validation: bool = False,
keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Visualizes materialization. This helps give you a sense of how materialization
will impact the DAG.
Expand Down Expand Up @@ -1572,6 +1589,7 @@ def visualize_materialization(
show_schema=show_schema,
custom_style_function=custom_style_function,
bypass_validation=bypass_validation,
keep_dot=keep_dot,
)

def validate_execution(
Expand Down
16 changes: 13 additions & 3 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import inspect
import logging
import os.path
import pathlib
import uuid
from enum import Enum
from types import ModuleType
Expand Down Expand Up @@ -747,6 +748,7 @@ def display_all(
deduplicate_inputs: bool = False,
display_fields: bool = True,
custom_style_function: Callable = None,
keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Displays & saves a dot file of the entire DAG structure constructed.
Expand All @@ -764,6 +766,7 @@ def display_all(
Can improve readability depending on the specifics of the DAG.
:param display_fields: If True, display fields in the graph if node has attached schema metadata
:param custom_style_function: Optional. Custom style function.
:param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
:return: the graphviz graph object if it was created. None if not.
"""
all_nodes = set()
Expand All @@ -789,6 +792,7 @@ def display_all(
display_fields=display_fields,
custom_style_function=custom_style_function,
config=self._config,
keep_dot=keep_dot,
)

def has_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> bool:
Expand Down Expand Up @@ -823,7 +827,7 @@ def get_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> List[
@staticmethod
def display(
nodes: Set[node.Node],
output_file_path: Optional[str] = "test-output/graph.gv",
output_file_path: Optional[str] = None,
render_kwargs: dict = None,
graphviz_kwargs: dict = None,
node_modifiers: Dict[str, Set[VisualizationNodeModifiers]] = None,
Expand All @@ -835,6 +839,7 @@ def display(
display_fields: bool = True,
custom_style_function: Callable = None,
config: dict = None,
keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Function to display the graph represented by the passed in nodes.
Expand Down Expand Up @@ -894,7 +899,7 @@ def display(
custom_style_function=custom_style_function,
config=config,
)
kwargs = {"view": False, "format": "png"} # default format = png
kwargs = {"format": "png"} # default format = png
if output_file_path: # infer format from path
output_file_path, suffix = os.path.splitext(output_file_path)
if suffix != "":
Expand All @@ -903,7 +908,12 @@ def display(
if render_kwargs and isinstance(render_kwargs, dict): # accept explicit format
kwargs.update(render_kwargs)
if output_file_path:
dot.render(output_file_path, **kwargs)
if keep_dot:
kwargs["view"] = kwargs.get("view", False)
dot.render(output_file_path, **kwargs)
else:
kwargs.pop("view", None)
pathlib.Path(output_file_path).write_bytes(dot.pipe(**kwargs))
return dot

def get_impacted_nodes(self, var_changes: List[str]) -> Set[node.Node]:
Expand Down
Loading

0 comments on commit 54b2a4f

Please sign in to comment.