Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature Improved DAG visualization #512

Merged
merged 10 commits into from
Nov 3, 2023
265 changes: 216 additions & 49 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

logger = logging.getLogger(__name__)

PATH_COLOR = "red"


class VisualizationNodeModifiers(Enum):
"""Enum of all possible node modifiers for visualization."""
Expand Down Expand Up @@ -140,6 +138,9 @@ def create_graphviz_graph(
graphviz_kwargs: dict,
node_modifiers: Dict[str, Set[VisualizationNodeModifiers]],
strictly_display_only_nodes_passed_in: bool,
left_to_right: bool = True,
zilto marked this conversation as resolved.
Show resolved Hide resolved
hide_inputs: bool = False,
deduplicate_inputs: bool = False,
) -> "graphviz.Digraph": # noqa: F821
"""Helper function to create a graphviz graph.

Expand All @@ -150,67 +151,233 @@ def create_graphviz_graph(
:param node_modifiers: A dictionary of node names to dictionaries of node attributes to modify.
:param strictly_display_only_nodes_passed_in: If True, only display the nodes passed in. Else defaults to displaying
also what nodes a node depends on (i.e. all nodes that feed into it).
:param left_to_right: If True, the DAG is displayed from left to right. Else, top to bottom. This value can
be overriden by the value of `graphviz_kwargs['graph_attr']['rankdir']`
:return: a graphviz.Digraph; use this to render/save a graph representation.
"""
PATH_COLOR = "#7A3B69"

import graphviz

digraph = graphviz.Digraph(comment=comment, **graphviz_kwargs)
def _get_node_label(n: node.Node, name=None, type_=None) -> str:
zilto marked this conversation as resolved.
Show resolved Hide resolved
name = n.name if name is None else name
type_ = n.type.__name__ if type_ is None else type_
return f"<<b>{name}</b><br /><br /><i>{type_}</i>>"

def _get_input_label(input_nodes: frozenset[node.Node]) -> str:
rows = [f"<tr><td>{dep.name}</td><td>{dep.type.__name__}</td></tr>" for dep in input_nodes]
return f"<<table border=\"0\">{''.join(rows)}</table>>"

def _get_node_type(n: node.Node) -> str:
if (
n._node_source == node.NodeType.EXTERNAL
and n._originating_functions is None
and n._depended_on_by
):
return "input"
elif (
n._node_source == node.NodeType.EXTERNAL
and n._originating_functions is None
and not n._depended_on_by
):
return "config"
else:
return "function"

def _get_node_style(node_type: str):
fontname = "Helvetica"

if node_type == "config":
node_style = dict(
shape="note",
style="",
fontname=fontname,
)
elif node_type == "input":
node_style = dict(
shape="rectangle",
margin="0.15",
style="dashed",
fontname=fontname,
)
elif node_type == "materializer":
node_style = dict(
shape="cylinder",
margin="0.15,0.1",
fontname=fontname,
)
else: # this is a function or else
node_style = dict(
shape="rectangle",
margin="0.15",
style="rounded,filled",
fillcolor="#b4d8e4",
fontname=fontname,
)

return node_style

def _get_function_modifier_style(modifier: str):
if modifier == "output":
modifier_style = dict(fillcolor="#FFC857")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will overwrite the prior fillcolor in certain cases, no?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but the only type of node with a fillcolor is "function nodes" (the default).

I wonder how people design visualization software, but would it make sense to have a sort of lexicon of all the possible combinations for internal purposes?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think I'm likely overthinking it though? As in, we can see what feedback we get?

elif modifier == "collect":
modifier_style = dict(peripheries="2", color="#EA5556")
elif modifier == "expand":
modifier_style = dict(peripheries="2", color="#56E39F")
elif modifier == "override":
modifier_style = dict(style="filled,diagonals")
elif modifier == "materializer":
modifier_style = dict(shape="cylinder")
else:
modifier_style = dict()

return modifier_style

def _get_edge_style(node_type: str):
zilto marked this conversation as resolved.
Show resolved Hide resolved
if node_type == "expand":
edge_style = dict(
dir="both",
arrowhead="crow",
arrowtail="none",
)
elif node_type == "collect":
edge_style = dict(dir="both", arrowtail="crow")
else:
edge_style = dict()

return edge_style

def _get_legend(node_types):
legend_subgraph = graphviz.Digraph(
name="cluster_legend", # needs to start with `cluster_` for graphviz layout
graph_attr=dict(
label="Legend",
rank="same",
fontname="helvetica",
),
)

sorted_types = [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit -- maybe put these in order of commonality? So scan order is useful?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved materializer downwards

It's a bespoke ordering, but I thought about having config and input first because they most often appear at the top of the graph near the legend. Then, all others are of function type with modifiers.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

"config",
"input",
"materializer",
"function",
"output",
"override",
"expand",
"collect",
]

for node_type in sorted(node_types, key=lambda t: sorted_types.index(t)):
node_style = _get_node_style(node_type)
modifier_style = _get_function_modifier_style(node_type)
node_style.update(**modifier_style)
legend_subgraph.node(name=node_type, **node_style)

return legend_subgraph

rankdir = "LR" if left_to_right is False else "TB"
# handle default values in nested dict
digraph_attr = dict(
comment=comment,
graph_attr=dict(
rankdir=rankdir,
ranksep="0.4",
compound="true",
concentrate="true",
),
)
digraph_attr.update(**graphviz_kwargs)
digraph = graphviz.Digraph(**digraph_attr)

# create nodes
seen_node_types = set()
for n in nodes:
label = n.name
other_args = {}
# checks if the node has any modifiers
if n.name in node_modifiers:
modifiers = node_modifiers[n.name]
# if node is an output, then modify the node to be a rectangle
label = _get_node_label(n)
node_type = _get_node_type(n)
seen_node_types.add(node_type)
if node_type == "input":
continue

node_style = _get_node_style(node_type)

# prefer having the conditions explicit for now since they rely on
# VisualizationNodeModifiers and node.Node.node_role
if modifiers := node_modifiers.get(n.name):
if VisualizationNodeModifiers.IS_OUTPUT in modifiers:
other_args["shape"] = "rectangle"
modifier_style = _get_function_modifier_style("output")
node_style.update(**modifier_style)
seen_node_types.add("output")

if VisualizationNodeModifiers.IS_OVERRIDE in modifiers:
modifier_style = _get_function_modifier_style("override")
node_style.update(**modifier_style)
seen_node_types.add("override")

if VisualizationNodeModifiers.IS_PATH in modifiers:
other_args["color"] = PATH_COLOR
node_style["color"] = PATH_COLOR

if VisualizationNodeModifiers.IS_USER_INPUT in modifiers:
other_args["style"] = "dashed"
label = f"Input: {n.name}"
if n.node_role == node.NodeType.EXPAND:
modifier_style = _get_function_modifier_style("expand")
node_style.update(**modifier_style)
seen_node_types.add("expand")

if VisualizationNodeModifiers.IS_OVERRIDE in modifiers:
other_args["style"] = "dashed"
label = f"Override: {n.name}"
is_expand_node = n.node_role == node.NodeType.EXPAND
is_collect_node = n.node_role == node.NodeType.COLLECT
if n.node_role == node.NodeType.COLLECT:
modifier_style = _get_function_modifier_style("collect")
node_style.update(**modifier_style)
seen_node_types.add("collect")

if n._tags.get("hamilton.data_saver"):
materializer_type = n._tags["hamilton.data_saver.classname"]
label = _get_node_label(n, type_=materializer_type)
modifier_style = _get_function_modifier_style("materializer")
node_style.update(**modifier_style)
seen_node_types.add("materializer")

if is_collect_node or is_expand_node:
other_args["peripheries"] = "2"
digraph.node(n.name, label=label, **other_args)
digraph.node(n.name, label=label, **node_style)

# create edges
input_sets = dict()
for n in nodes:
input_nodes = set()

for n in list(nodes):
for d in n.dependencies:
if strictly_display_only_nodes_passed_in and d not in nodes:
continue
if (
d not in nodes
and d.name in node_modifiers
and VisualizationNodeModifiers.IS_USER_INPUT in node_modifiers[d.name]
):
digraph.node(d.name, label=f"Input: {d.name}", style="dashed")
from_modifiers = node_modifiers.get(d.name, set())
to_modifiers = node_modifiers.get(n.name, set())
other_args = {}
if (
VisualizationNodeModifiers.IS_PATH in from_modifiers
and VisualizationNodeModifiers.IS_PATH in to_modifiers
):
other_args["color"] = PATH_COLOR
is_collect_edge = n.node_role == node.NodeType.COLLECT
is_expand_edge = d.node_role == node.NodeType.EXPAND

if is_collect_edge:
other_args["dir"] = "both"
other_args["arrowtail"] = "crow"

if is_expand_edge:
other_args["dir"] = "both"
other_args["arrowhead"] = "crow"
other_args["arrowtail"] = "none"
digraph.edge(d.name, n.name, **other_args)

node_type = _get_node_type(d)
if node_type == "input":
input_nodes.add(d)
continue
zilto marked this conversation as resolved.
Show resolved Hide resolved

edge_style = _get_edge_style(d)
digraph.edge(d.name, n.name, **edge_style)

# skip input node creation
if hide_inputs:
continue

if len(input_nodes) > 0:
input_node_name = f"{n.name}_inputs"
zilto marked this conversation as resolved.
Show resolved Hide resolved

# following block is for input node deduplication
input_nodes = frozenset(input_nodes)
if existing_input_name := input_sets.get(input_nodes):
zilto marked this conversation as resolved.
Show resolved Hide resolved
digraph.edge(existing_input_name, n.name)
continue

# allow duplicate input nodes by never storing keys
if deduplicate_inputs:
input_sets[input_nodes] = input_node_name

# create input node
node_label = _get_input_label(input_nodes)
node_style = _get_node_style("input")
digraph.node(name=input_node_name, label=node_label, **node_style)
# create edge for input node
digraph.edge(input_node_name, n.name)

digraph.subgraph(_get_legend(seen_node_types))
return digraph


Expand Down