Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/remove tmp dot #889

Merged
merged 4 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,7 @@ def display_all_functions(
deduplicate_inputs: bool = False,
show_schema: bool = True,
custom_style_function: Callable = None,
keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Displays the graph of all functions loaded!

Expand All @@ -767,6 +768,7 @@ def display_all_functions(
:param show_schema: If True, display the schema of the DAG if
the nodes have schema data provided
:param custom_style_function: Optional. Custom style function. See example in repository for example use.
:param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
:return: the graphviz object if you want to do more with it.
If returned as the result in a Jupyter Notebook cell, it will render.
"""
Expand All @@ -781,6 +783,7 @@ def display_all_functions(
deduplicate_inputs=deduplicate_inputs,
display_fields=show_schema,
custom_style_function=custom_style_function,
keep_dot=keep_dot,
)
except ImportError as e:
logger.warning(f"Unable to import {e}", exc_info=True)
Expand All @@ -802,6 +805,7 @@ def _visualize_execution_helper(
show_schema: bool = True,
custom_style_function: Callable = None,
bypass_validation: bool = False,
keep_dot: bool = False,
):
"""Helper function to visualize execution, using a passed-in function graph.

Expand All @@ -816,6 +820,7 @@ def _visualize_execution_helper(
:param deduplicate_inputs: If True, remove duplicate input nodes.
:param show_schema: If True, display the schema of the DAG if nodes have schema data provided
:param custom_style_function: Optional. Custom style function.
:param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
:return: the graphviz object if you want to do more with it.
"""
# TODO should determine if the visualization logic should live here or in the graph.py module
Expand Down Expand Up @@ -851,6 +856,7 @@ def _visualize_execution_helper(
display_fields=show_schema,
custom_style_function=custom_style_function,
config=fn_graph._config,
keep_dot=keep_dot,
)
except ImportError as e:
logger.warning(f"Unable to import {e}", exc_info=True)
Expand All @@ -871,6 +877,7 @@ def visualize_execution(
show_schema: bool = True,
custom_style_function: Callable = None,
bypass_validation: bool = False,
keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Visualizes Execution.

Expand Down Expand Up @@ -902,6 +909,7 @@ def visualize_execution(
Can improve readability depending on the specifics of the DAG.
:param show_schema: If True, display the schema of the DAG if nodes have schema data provided
:param custom_style_function: Optional. Custom style function.
:param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
:return: the graphviz object if you want to do more with it.
If returned as the result in a Jupyter Notebook cell, it will render.
"""
Expand All @@ -922,6 +930,7 @@ def visualize_execution(
show_schema=show_schema,
custom_style_function=custom_style_function,
bypass_validation=bypass_validation,
keep_dot=keep_dot,
)

@capture_function_usage
Expand Down Expand Up @@ -988,6 +997,7 @@ def display_downstream_of(
deduplicate_inputs: bool = False,
show_schema: bool = True,
custom_style_function: Callable = None,
keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Creates a visualization of the DAG starting from the passed in function name(s).

Expand All @@ -1010,6 +1020,7 @@ def display_downstream_of(
Can improve readability depending on the specifics of the DAG.
:param show_schema: If True, display the schema of the DAG if nodes have schema data provided
:param custom_style_function: Optional. Custom style function.
:param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
:return: the graphviz object if you want to do more with it.
If returned as the result in a Jupyter Notebook cell, it will render.
"""
Expand Down Expand Up @@ -1054,6 +1065,7 @@ def display_upstream_of(
deduplicate_inputs: bool = False,
show_schema: bool = True,
custom_style_function: Callable = None,
keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Creates a visualization of the DAG going backwards from the passed in function name(s).

Expand All @@ -1076,6 +1088,7 @@ def display_upstream_of(
Can improve readability depending on the specifics of the DAG.
:param show_schema: If True, display the schema of the DAG if nodes have schema data provided
:param custom_style_function: Optional. Custom style function.
:param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
:return: the graphviz object if you want to do more with it.
If returned as the result in a Jupyter Notebook cell, it will render.
"""
Expand Down Expand Up @@ -1172,6 +1185,7 @@ def visualize_path_between(
deduplicate_inputs: bool = False,
show_schema: bool = True,
custom_style_function: Callable = None,
keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Visualizes the path between two nodes.

Expand All @@ -1197,6 +1211,7 @@ def visualize_path_between(
:param show_schema: If True, display the schema of the DAG if nodes have schema data provided
:return: graphviz object.
:param custom_style_function: Optional. Custom style function.
:param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
:raise ValueError: if the upstream or downstream node names are not found in the graph,
or there is no path between them.
"""
Expand Down Expand Up @@ -1256,6 +1271,7 @@ def visualize_path_between(
display_fields=show_schema,
custom_style_function=custom_style_function,
config=self.graph._config,
keep_dot=keep_dot,
)
except ImportError as e:
logger.warning(f"Unable to import {e}", exc_info=True)
Expand Down Expand Up @@ -1523,6 +1539,7 @@ def visualize_materialization(
show_schema: bool = True,
custom_style_function: Callable = None,
bypass_validation: bool = False,
keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Visualizes materialization. This helps give you a sense of how materialization
will impact the DAG.
Expand Down Expand Up @@ -1572,6 +1589,7 @@ def visualize_materialization(
show_schema=show_schema,
custom_style_function=custom_style_function,
bypass_validation=bypass_validation,
keep_dot=keep_dot,
)

def validate_execution(
Expand Down
20 changes: 17 additions & 3 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import inspect
import logging
import os.path
import pathlib
import uuid
from enum import Enum
from types import ModuleType
Expand Down Expand Up @@ -747,6 +748,7 @@ def display_all(
deduplicate_inputs: bool = False,
display_fields: bool = True,
custom_style_function: Callable = None,
keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Displays & saves a dot file of the entire DAG structure constructed.

Expand All @@ -764,6 +766,7 @@ def display_all(
Can improve readability depending on the specifics of the DAG.
:param display_fields: If True, display fields in the graph if node has attached schema metadata
:param custom_style_function: Optional. Custom style function.
:param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
:return: the graphviz graph object if it was created. None if not.
"""
all_nodes = set()
Expand All @@ -789,6 +792,7 @@ def display_all(
display_fields=display_fields,
custom_style_function=custom_style_function,
config=self._config,
keep_dot=keep_dot,
)

def has_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> bool:
Expand Down Expand Up @@ -823,7 +827,7 @@ def get_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> List[
@staticmethod
def display(
nodes: Set[node.Node],
output_file_path: Optional[str] = "test-output/graph.gv",
output_file_path: Optional[str] = None,
render_kwargs: dict = None,
graphviz_kwargs: dict = None,
node_modifiers: Dict[str, Set[VisualizationNodeModifiers]] = None,
Expand All @@ -835,6 +839,7 @@ def display(
display_fields: bool = True,
custom_style_function: Callable = None,
config: dict = None,
keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Function to display the graph represented by the passed in nodes.

Expand Down Expand Up @@ -894,16 +899,25 @@ def display(
custom_style_function=custom_style_function,
config=config,
)
kwargs = {"view": False, "format": "png"} # default format = png
kwargs = {"format": "png"} # default format = png
if output_file_path: # infer format from path
output_file_path, suffix = os.path.splitext(output_file_path)
if suffix != "":
inferred_format = suffix.partition(".")[-1]
kwargs.update(format=inferred_format)
if render_kwargs and isinstance(render_kwargs, dict): # accept explicit format
kwargs.update(render_kwargs)
# .render()` and `.pipe()` have quirks to handle separately
# - `render()` accepts a `view` kwarg
# - `render()` will append it's kwarg `format` to the filename
if output_file_path:
dot.render(output_file_path, **kwargs)
if keep_dot:
kwargs["view"] = kwargs.get("view", False)
dot.render(output_file_path, **kwargs)
else:
kwargs.pop("view", None)
output_file_path = f"{output_file_path}.{kwargs['format']}"
pathlib.Path(output_file_path).write_bytes(dot.pipe(**kwargs))
return dot

def get_impacted_nodes(self, var_changes: List[str]) -> Set[node.Node]:
Expand Down
Loading