Skip to content

Commit

Permalink
Adds visualize_path_between function to driver
Browse files Browse the repository at this point in the history
This is useful for debugging and zooming in on a particular path.
People who want to get more out of lineage will want to use this
function to help them document, debug, and understand how
particular things are related.
  • Loading branch information
skrawcz committed May 22, 2023
1 parent da1aeeb commit 7f4d151
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 356 deletions.
Binary file added examples/hello_world/a_path.dot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/hello_world/my_dag.dot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
432 changes: 81 additions & 351 deletions examples/hello_world/my_notebook.ipynb

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions examples/hello_world/my_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,11 @@

# To visualize do `pip install "sf-hamilton[visualization]"` if you want these to work
dr.visualize_execution(output_columns, "./my_dag.dot", {"format": "png"})
# Render just the part of the DAG that connects `spend_mean` to
# `spend_zero_mean_unit_variance` and save it next to this script as a PNG.
# strict_path_visualization=False also draws the parents of the nodes on the
# path, not only the path nodes themselves.
dr.visualize_path_between(
"spend_mean",
"spend_zero_mean_unit_variance",
"./a_path.dot",
{"format": "png"},
strict_path_visualization=False,
)
# dr.display_all_functions("./my_full_dag.dot", {"format": "png"})
70 changes: 70 additions & 0 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,76 @@ def what_is_upstream_of(self, *node_names: str) -> List[Variable]:
upstream_nodes, _ = self.graph.get_upstream_nodes(list(node_names))
return [Variable.from_node(n) for n in upstream_nodes]

def visualize_path_between(
    self,
    upstream_node_name: str,
    downstream_node_name: str,
    output_file_path: Optional[str] = None,
    render_kwargs: Optional[dict] = None,
    graphviz_kwargs: Optional[dict] = None,
    strict_path_visualization: bool = False,
) -> Optional["graphviz.Digraph"]:  # noqa F821
    """Visualizes the path between two nodes.

    This is useful for debugging and understanding how two nodes are related.

    :param upstream_node_name: the name of the node that we want to start from.
    :param downstream_node_name: the name of the node that we want to end at.
    :param output_file_path: the full URI of path + file name to save the dot file to.
        E.g. 'some/path/graph.dot'. Pass in None to skip saving any file.
    :param render_kwargs: a dictionary of values we'll pass to graphviz render function.
        Defaults to viewing. If you do not want to view the file, pass in `{'view':False}`.
    :param graphviz_kwargs: kwargs to be passed to the graphviz graph object to configure it.
        E.g. dict(graph_attr={'ratio': '1'}) requests an equal aspect ratio for the produced image.
    :param strict_path_visualization: if True, only the nodes in the path will be visualized.
        If False, the nodes in the path and their dependencies, i.e. parents, will be visualized.
    :return: whatever `self.graph.display(...)` returns (the graphviz object), or None if an
        ImportError was raised while displaying; in that case a warning is logged instead.
    :raise ValueError: if the upstream or downstream node names are not found in the graph,
        or there is no path between them.
    """
    render_kwargs = {} if render_kwargs is None else render_kwargs
    graphviz_kwargs = {} if graphviz_kwargs is None else graphviz_kwargs

    # Fetch the nodes once; they are needed both to validate the requested
    # names and to mark which nodes are user-provided inputs.
    all_nodes = list(self.graph.get_nodes())
    known_names = {n.name for n in all_nodes}
    if upstream_node_name not in known_names:
        raise ValueError(f"Upstream node {upstream_node_name} not found in graph.")
    if downstream_node_name not in known_names:
        raise ValueError(f"Downstream node {downstream_node_name} not found in graph.")

    # set whether the node is user input
    node_modifiers = {n.name: {"is_user_input": True} for n in all_nodes if n.user_defined}

    # The path is made of the nodes that are both impacted by the upstream
    # node and required by the downstream node.
    downstream_nodes = set(self.graph.get_impacted_nodes([upstream_node_name]))
    upstream_nodes, _ = self.graph.get_upstream_nodes([downstream_node_name])
    nodes_for_path = downstream_nodes.intersection(upstream_nodes)
    if not nodes_for_path:
        raise ValueError(
            f"No path found between {upstream_node_name} and {downstream_node_name}."
        )

    # Flag path nodes, keeping any modifier (e.g. is_user_input) already set.
    for n in nodes_for_path:
        node_modifiers.setdefault(n.name, {})["is_path"] = True

    try:
        return self.graph.display(
            nodes_for_path,
            output_file_path,
            render_kwargs=render_kwargs,
            graphviz_kwargs=graphviz_kwargs,
            node_modifiers=node_modifiers,
            strictly_display_only_passed_in_nodes=strict_path_visualization,
        )
    except ImportError as e:
        # Best effort: a missing optional visualization dependency should not
        # crash the caller, so log it and signal "nothing rendered" explicitly.
        logger.warning(f"Unable to import {e}", exc_info=True)
        return None


if __name__ == "__main__":
"""some example test code"""
Expand Down
6 changes: 1 addition & 5 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,7 @@ def create_graphviz_graph(
for d in n.dependencies:
if strictly_display_only_nodes_passed_in and d not in nodes:
continue
if (
d not in nodes
and d.name in node_modifiers
and node_modifiers[d.name].get("is_user_input")
):
if d.name in node_modifiers and node_modifiers[d.name].get("is_user_input"):
digraph.node(d.name, label=f"Input: {d.name}", style="dashed")
# print(f"Adding edge from {d.name} to {n.name}")
digraph.edge(d.name, n.name)
Expand Down

0 comments on commit 7f4d151

Please sign in to comment.