From a6697273ddf461cc25ec476f163a7b6b60c9b68d Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Sat, 30 Apr 2022 21:57:55 -0700 Subject: [PATCH] Exposes passing kwargs to graphviz object MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This should be a backwards-compatible API change for everything. Some properties of the produced image, for instance its aspect ratio, cannot be controlled with a render-time argument; they are properties of the graph itself. To be able to manipulate them, we need to wire through a way to pass in attributes. E.g. https://graphviz.org/doc/info/attrs.html shows what can be set. These attributes are either graph, node, or edge attributes. The kwargs I'm specifically trying to expose are: ``` graph_attr – Mapping of (attribute, value) pairs for the graph. node_attr – Mapping of (attribute, value) pairs set for all nodes. edge_attr – Mapping of (attribute, value) pairs set for all edges. ``` The above was taken from https://graphviz.readthedocs.io/en/stable/api.html?highlight=ratio#digraph. Note that this also allows any other keyword argument to be passed to the graph object, which I think is useful too, since it lets someone who knows graphviz really customize the output. --- hamilton/driver.py | 29 +++++++++++++++++++++++------ hamilton/graph.py | 31 +++++++++++++++++++++++-------- tests/test_graph.py | 3 ++- 3 files changed, 48 insertions(+), 15 deletions(-) diff --git a/hamilton/driver.py b/hamilton/driver.py index 971bc861..490323bd 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -165,16 +165,20 @@ def list_available_variables(self) -> List[Variable]: """ return [Variable(node.name, node.type, node.tags) for node in self.graph.get_nodes()] - def display_all_functions(self, output_file_path: str, render_kwargs: dict = None): + def display_all_functions(self, output_file_path: str, render_kwargs: dict = None, graphviz_kwargs: dict = None): """Displays the graph of all functions loaded! 
:param output_file_path: the full URI of path + file name to save the dot file to. E.g. 'some/path/graph-all.dot' :param render_kwargs: a dictionary of values we'll pass to graphviz render function. Defaults to viewing. If you do not want to view the file, pass in `{'view':False}`. + See https://graphviz.readthedocs.io/en/stable/api.html#graphviz.Graph.render for other options. + :param graphviz_kwargs: Optional. Kwargs to be passed to the graphviz graph object to configure it. + E.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image. + See https://graphviz.org/doc/info/attrs.html for options. """ try: - self.graph.display_all(output_file_path, render_kwargs) + self.graph.display_all(output_file_path, render_kwargs, graphviz_kwargs) except ImportError as e: logger.warning(f'Unable to import {e}', exc_info=True) @@ -182,7 +186,8 @@ def visualize_execution(self, final_vars: List[str], output_file_path: str, render_kwargs: dict, - inputs: Dict[str, Any] = None): + inputs: Dict[str, Any] = None, + graphviz_kwargs: dict = None): """Visualizes Execution. Note: overrides are not handled at this time. @@ -192,12 +197,17 @@ def visualize_execution(self, E.g. 'some/path/graph.dot' :param render_kwargs: a dictionary of values we'll pass to graphviz render function. Defaults to viewing. If you do not want to view the file, pass in `{'view':False}`. + See https://graphviz.readthedocs.io/en/stable/api.html#graphviz.Graph.render for other options. :param inputs: Optional. Runtime inputs to the DAG. + :param graphviz_kwargs: Optional. Kwargs to be passed to the graphviz graph object to configure it. + E.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image. + See https://graphviz.org/doc/info/attrs.html for options. 
""" nodes, user_nodes = self.graph.get_upstream_nodes(final_vars, inputs) self.validate_inputs(user_nodes, inputs) try: - self.graph.display(nodes, user_nodes, output_file_path, render_kwargs=render_kwargs) + self.graph.display(nodes, user_nodes, output_file_path, + render_kwargs=render_kwargs, graphviz_kwargs=graphviz_kwargs) except ImportError as e: logger.warning(f'Unable to import {e}', exc_info=True) @@ -221,7 +231,11 @@ def what_is_downstream_of(self, *node_names: str) -> List[Variable]: downstream_nodes = self.graph.get_impacted_nodes(list(node_names)) return [Variable(node.name, node.type, node.tags) for node in downstream_nodes] - def display_downstream_of(self, *node_names: str, output_file_path: str, render_kwargs: dict): + def display_downstream_of(self, + *node_names: str, + output_file_path: str, + render_kwargs: dict, + graphviz_kwargs: dict): """Creates a visualization of the DAG starting from the passed in function name(s). Note: for any "node" visualized, we will also add its parents to the visualization as well, so @@ -232,10 +246,13 @@ def display_downstream_of(self, *node_names: str, output_file_path: str, render_ E.g. 'some/path/graph.dot' :param render_kwargs: a dictionary of values we'll pass to graphviz render function. Defaults to viewing. If you do not want to view the file, pass in `{'view':False}`. + :param graphviz_kwargs: Kwargs to be passed to the graphviz graph object to configure it. + E.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image. 
""" downstream_nodes = self.graph.get_impacted_nodes(list(node_names)) try: - self.graph.display(downstream_nodes, set(), output_file_path, render_kwargs=render_kwargs) + self.graph.display(downstream_nodes, set(), output_file_path, + render_kwargs=render_kwargs, graphviz_kwargs=graphviz_kwargs) except ImportError as e: logger.warning(f'Unable to import {e}', exc_info=True) diff --git a/hamilton/graph.py b/hamilton/graph.py index b286fe1b..efb02742 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -168,16 +168,19 @@ def create_function_graph(*modules: ModuleType, config: Dict[str, Any], adapter: return nodes -def create_graphviz_graph(nodes: Set[node.Node], user_nodes: Set[node.Node], comment: str) -> 'graphviz.Digraph': +def create_graphviz_graph(nodes: Set[node.Node], user_nodes: Set[node.Node], comment: str, + graphviz_kwargs: dict) -> 'graphviz.Digraph': """Helper function to create a graphviz graph. :param nodes: The set of computational nodes :param user_nodes: The set of nodes that the user is providing inputs for. :param comment: The comment to have on the graph. + :param graphviz_kwargs: kwargs to pass to create the graph. + e.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image. :return: a graphviz.Digraph; use this to render/save a graph representation. """ import graphviz - digraph = graphviz.Digraph(comment=comment) + digraph = graphviz.Digraph(comment=comment, **graphviz_kwargs) for n in nodes: digraph.node(n.name, label=n.name) for n in user_nodes: @@ -238,12 +241,17 @@ def config(self): def get_nodes(self) -> List[node.Node]: return list(self.nodes.values()) - def display_all(self, output_file_path: str = 'test-output/graph-all.gv', render_kwargs: dict = None): + def display_all(self, + output_file_path: str = 'test-output/graph-all.gv', + render_kwargs: dict = None, + graphviz_kwargs: dict = None): """Displays & saves a dot file of the entire DAG structure constructed. 
:param output_file_path: the place to save the files. :param render_kwargs: a dictionary of values we'll pass to graphviz render function. Defaults to viewing. If you do not want to view the file, pass in `{'view':False}`. + :param graphviz_kwargs: kwargs to be passed to the graphviz graph object to configure it. + e.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image. """ defined_nodes = set() user_nodes = set() @@ -254,7 +262,10 @@ def display_all(self, output_file_path: str = 'test-output/graph-all.gv', render defined_nodes.add(n) if render_kwargs is None: render_kwargs = {} - self.display(defined_nodes, user_nodes, output_file_path=output_file_path, render_kwargs=render_kwargs) + if graphviz_kwargs is None: + graphviz_kwargs = {} + self.display(defined_nodes, user_nodes, + output_file_path=output_file_path, render_kwargs=render_kwargs, graphviz_kwargs=graphviz_kwargs) def has_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> bool: """Checks that the graph created does not contain cycles. @@ -289,13 +300,16 @@ def get_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> List[ def display(nodes: Set[node.Node], user_nodes: Set[node.Node], output_file_path: str = 'test-output/graph.gv', - render_kwargs: dict = None): + render_kwargs: dict = None, + graphviz_kwargs: dict = None): """Function to display the graph represented by the passed in nodes. :param nodes: the set of nodes that need to be computed. :param user_nodes: the set of inputs that the user provided. :param output_file_path: the path where we want to store the a `dot` file + pdf picture. :param render_kwargs: kwargs to be passed to the render function to visualize. + :param graphviz_kwargs: kwargs to be passed to the graphviz graph object to configure it. + e.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image. """ # Check to see if optional dependencies have been installed. 
try: @@ -306,10 +320,11 @@ def display(nodes: Set[node.Node], '\n\n pip install sf-hamilton[visualization] or pip install graphviz \n\n' ) return - - dot = create_graphviz_graph(nodes, user_nodes, 'Dependency Graph') + if graphviz_kwargs is None: + graphviz_kwargs = {} + dot = create_graphviz_graph(nodes, user_nodes, 'Dependency Graph', graphviz_kwargs) kwargs = {'view': True} - if kwargs and isinstance(render_kwargs, dict): + if render_kwargs and isinstance(render_kwargs, dict): kwargs.update(render_kwargs) dot.render(output_file_path, **kwargs) diff --git a/tests/test_graph.py b/tests/test_graph.py index bc4ab11a..542c34b0 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -366,6 +366,7 @@ def test_create_graphviz_graph(): # why? because for some reason given the same graph, the output file isn't deterministic. expected = sorted(['// test-graph', 'digraph {', + '\tgraph [ratio=1]', '\tB [label=B]', '\tA [label=A]', '\tc [label=c]', @@ -381,7 +382,7 @@ def test_create_graphviz_graph(): '']) if '' in expected: expected.remove('') - digraph = graph.create_graphviz_graph(nodes, user_nodes, 'test-graph') + digraph = graph.create_graphviz_graph(nodes, user_nodes, 'test-graph', dict(graph_attr={'ratio': '1'})) actual = sorted(str(digraph).split('\n')) if '' in actual: actual.remove('')