Skip to content
This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Exposes passing kwargs to graphviz object #125

Merged
merged 1 commit into from
May 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,24 +165,29 @@ def list_available_variables(self) -> List[Variable]:
"""
return [Variable(node.name, node.type, node.tags) for node in self.graph.get_nodes()]

def display_all_functions(self, output_file_path: str, render_kwargs: dict = None):
def display_all_functions(self, output_file_path: str, render_kwargs: dict = None, graphviz_kwargs: dict = None):
"""Displays the graph of all functions loaded!

:param output_file_path: the full URI of path + file name to save the dot file to.
E.g. 'some/path/graph-all.dot'
:param render_kwargs: a dictionary of values we'll pass to graphviz render function. Defaults to viewing.
If you do not want to view the file, pass in `{'view':False}`.
See https://graphviz.readthedocs.io/en/stable/api.html#graphviz.Graph.render for other options.
:param graphviz_kwargs: Optional. Kwargs to be passed to the graphviz graph object to configure it.
E.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image.
See https://graphviz.org/doc/info/attrs.html for options.
"""
try:
self.graph.display_all(output_file_path, render_kwargs)
self.graph.display_all(output_file_path, render_kwargs, graphviz_kwargs)
except ImportError as e:
logger.warning(f'Unable to import {e}', exc_info=True)

def visualize_execution(self,
final_vars: List[str],
output_file_path: str,
render_kwargs: dict,
inputs: Dict[str, Any] = None):
inputs: Dict[str, Any] = None,
graphviz_kwargs: dict = None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain the difference between this and **render_kwargs? I think it makes sense, but it's not obvious.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

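Render is called after you construct the graph and want to create an image from it, so render_kwargs are passed to that render call. graphviz_kwargs, by contrast, are passed to the graphviz.Digraph constructor when the graph object itself is created.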

https://graphviz.readthedocs.io/en/stable/api.html#graphviz.Graph.render

"""Visualizes Execution.

Note: overrides are not handled at this time.
Expand All @@ -192,12 +197,17 @@ def visualize_execution(self,
E.g. 'some/path/graph.dot'
:param render_kwargs: a dictionary of values we'll pass to graphviz render function. Defaults to viewing.
If you do not want to view the file, pass in `{'view':False}`.
See https://graphviz.readthedocs.io/en/stable/api.html#graphviz.Graph.render for other options.
:param inputs: Optional. Runtime inputs to the DAG.
:param graphviz_kwargs: Optional. Kwargs to be passed to the graphviz graph object to configure it.
E.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image.
See https://graphviz.org/doc/info/attrs.html for options.
"""
nodes, user_nodes = self.graph.get_upstream_nodes(final_vars, inputs)
self.validate_inputs(user_nodes, inputs)
try:
self.graph.display(nodes, user_nodes, output_file_path, render_kwargs=render_kwargs)
self.graph.display(nodes, user_nodes, output_file_path,
render_kwargs=render_kwargs, graphviz_kwargs=graphviz_kwargs)
except ImportError as e:
logger.warning(f'Unable to import {e}', exc_info=True)

Expand All @@ -221,7 +231,11 @@ def what_is_downstream_of(self, *node_names: str) -> List[Variable]:
downstream_nodes = self.graph.get_impacted_nodes(list(node_names))
return [Variable(node.name, node.type, node.tags) for node in downstream_nodes]

def display_downstream_of(self, *node_names: str, output_file_path: str, render_kwargs: dict):
def display_downstream_of(self,
*node_names: str,
output_file_path: str,
render_kwargs: dict,
graphviz_kwargs: dict):
"""Creates a visualization of the DAG starting from the passed in function name(s).

Note: for any "node" visualized, we will also add its parents to the visualization as well, so
Expand All @@ -232,10 +246,13 @@ def display_downstream_of(self, *node_names: str, output_file_path: str, render_
E.g. 'some/path/graph.dot'
:param render_kwargs: a dictionary of values we'll pass to graphviz render function. Defaults to viewing.
If you do not want to view the file, pass in `{'view':False}`.
:param graphviz_kwargs: Kwargs to be passed to the graphviz graph object to configure it.
E.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image.
"""
downstream_nodes = self.graph.get_impacted_nodes(list(node_names))
try:
self.graph.display(downstream_nodes, set(), output_file_path, render_kwargs=render_kwargs)
self.graph.display(downstream_nodes, set(), output_file_path,
render_kwargs=render_kwargs, graphviz_kwargs=graphviz_kwargs)
except ImportError as e:
logger.warning(f'Unable to import {e}', exc_info=True)

Expand Down
31 changes: 23 additions & 8 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,19 @@ def create_function_graph(*modules: ModuleType, config: Dict[str, Any], adapter:
return nodes


def create_graphviz_graph(nodes: Set[node.Node], user_nodes: Set[node.Node], comment: str) -> 'graphviz.Digraph':
def create_graphviz_graph(nodes: Set[node.Node], user_nodes: Set[node.Node], comment: str,
graphviz_kwargs: dict) -> 'graphviz.Digraph':
"""Helper function to create a graphviz graph.

:param nodes: The set of computational nodes
:param user_nodes: The set of nodes that the user is providing inputs for.
:param comment: The comment to have on the graph.
:param graphviz_kwargs: kwargs to pass to create the graph.
e.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image.
:return: a graphviz.Digraph; use this to render/save a graph representation.
"""
import graphviz
digraph = graphviz.Digraph(comment=comment)
digraph = graphviz.Digraph(comment=comment, **graphviz_kwargs)
for n in nodes:
digraph.node(n.name, label=n.name)
for n in user_nodes:
Expand Down Expand Up @@ -238,12 +241,17 @@ def config(self):
def get_nodes(self) -> List[node.Node]:
return list(self.nodes.values())

def display_all(self, output_file_path: str = 'test-output/graph-all.gv', render_kwargs: dict = None):
def display_all(self,
output_file_path: str = 'test-output/graph-all.gv',
render_kwargs: dict = None,
graphviz_kwargs: dict = None):
"""Displays & saves a dot file of the entire DAG structure constructed.

:param output_file_path: the place to save the files.
:param render_kwargs: a dictionary of values we'll pass to graphviz render function. Defaults to viewing.
If you do not want to view the file, pass in `{'view':False}`.
:param graphviz_kwargs: kwargs to be passed to the graphviz graph object to configure it.
e.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image.
"""
defined_nodes = set()
user_nodes = set()
Expand All @@ -254,7 +262,10 @@ def display_all(self, output_file_path: str = 'test-output/graph-all.gv', render
defined_nodes.add(n)
if render_kwargs is None:
render_kwargs = {}
self.display(defined_nodes, user_nodes, output_file_path=output_file_path, render_kwargs=render_kwargs)
if graphviz_kwargs is None:
graphviz_kwargs = {}
self.display(defined_nodes, user_nodes,
output_file_path=output_file_path, render_kwargs=render_kwargs, graphviz_kwargs=graphviz_kwargs)

def has_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> bool:
"""Checks that the graph created does not contain cycles.
Expand Down Expand Up @@ -289,13 +300,16 @@ def get_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> List[
def display(nodes: Set[node.Node],
user_nodes: Set[node.Node],
output_file_path: str = 'test-output/graph.gv',
render_kwargs: dict = None):
render_kwargs: dict = None,
graphviz_kwargs: dict = None):
"""Function to display the graph represented by the passed in nodes.

:param nodes: the set of nodes that need to be computed.
:param user_nodes: the set of inputs that the user provided.
:param output_file_path: the path where we want to store the a `dot` file + pdf picture.
:param render_kwargs: kwargs to be passed to the render function to visualize.
:param graphviz_kwargs: kwargs to be passed to the graphviz graph object to configure it.
e.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image.
"""
# Check to see if optional dependencies have been installed.
try:
Expand All @@ -306,10 +320,11 @@ def display(nodes: Set[node.Node],
'\n\n pip install sf-hamilton[visualization] or pip install graphviz \n\n'
)
return

dot = create_graphviz_graph(nodes, user_nodes, 'Dependency Graph')
if graphviz_kwargs is None:
graphviz_kwargs = {}
dot = create_graphviz_graph(nodes, user_nodes, 'Dependency Graph', graphviz_kwargs)
kwargs = {'view': True}
if kwargs and isinstance(render_kwargs, dict):
if render_kwargs and isinstance(render_kwargs, dict):
kwargs.update(render_kwargs)
dot.render(output_file_path, **kwargs)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def test_create_graphviz_graph():
# why? because for some reason given the same graph, the output file isn't deterministic.
expected = sorted(['// test-graph',
'digraph {',
'\tgraph [ratio=1]',
'\tB [label=B]',
'\tA [label=A]',
'\tc [label=c]',
Expand All @@ -381,7 +382,7 @@ def test_create_graphviz_graph():
''])
if '' in expected:
expected.remove('')
digraph = graph.create_graphviz_graph(nodes, user_nodes, 'test-graph')
digraph = graph.create_graphviz_graph(nodes, user_nodes, 'test-graph', dict(graph_attr={'ratio': '1'}))
actual = sorted(str(digraph).split('\n'))
if '' in actual:
actual.remove('')
Expand Down