diff --git a/docs/reference/api/python/contrib.rst b/docs/reference/api/python/contrib.rst index 0eb3024c2d08..52d3faff0fc4 100644 --- a/docs/reference/api/python/contrib.rst +++ b/docs/reference/api/python/contrib.rst @@ -93,6 +93,16 @@ tvm.contrib.random :members: +tvm.contrib.relay_viz +~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: tvm.contrib.relay_viz + :members: +.. automodule:: tvm.contrib.relay_viz.interface + :members: +.. automodule:: tvm.contrib.relay_viz.terminal + :members: + + tvm.contrib.rocblas ~~~~~~~~~~~~~~~~~~~ .. automodule:: tvm.contrib.rocblas diff --git a/gallery/how_to/work_with_relay/using_relay_viz.py b/gallery/how_to/work_with_relay/using_relay_viz.py new file mode 100644 index 000000000000..f61fc41f4f14 --- /dev/null +++ b/gallery/how_to/work_with_relay/using_relay_viz.py @@ -0,0 +1,159 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=line-too-long +""" +Use Relay Visualizer to Visualize Relay +============================================================ +**Author**: `Chi-Wei Wang `_ + +Relay IR module can contain lots of operations. Although an individual +operation is usually easy to understand, putting them together can cause +a complicated, hard-to-read graph. Things can get even worse with optimiztion-passes +coming into play. + +This utility visualizes an IR module as nodes and edges. It defines a set of interfaces including +parser, plotter(renderer), graph, node, and edges. +A default parser is provided. Users can implement their own renderers to render the graph. + +Here we use a renderer rendering graph in the text-form. +It is a lightweight, AST-like visualizer, inspired by `clang ast-dump `_. +We will introduce how to implement customized parsers and renderers through interface classes. +""" +from typing import ( + Dict, + Union, + Tuple, + List, +) +import tvm +from tvm import relay +from tvm.contrib import relay_viz +from tvm.contrib.relay_viz.interface import ( + VizEdge, + VizNode, + VizParser, +) +from tvm.contrib.relay_viz.terminal import ( + TermGraph, + TermPlotter, + TermVizParser, +) + +###################################################################### +# Define a Relay IR Module with multiple GlobalVar +# ------------------------------------------------ +# Let's build an example Relay IR Module containing multiple ``GlobalVar``. +# We define an ``add`` function and call it in the main function. +data = relay.var("data") +bias = relay.var("bias") +add_op = relay.add(data, bias) +add_func = relay.Function([data, bias], add_op) +add_gvar = relay.GlobalVar("AddFunc") + +input0 = relay.var("input0") +input1 = relay.var("input1") +input2 = relay.var("input2") +add_01 = relay.Call(add_gvar, [input0, input1]) +add_012 = relay.Call(add_gvar, [input2, add_01]) +main_func = relay.Function([input0, input1, input2], add_012) +main_gvar = relay.GlobalVar("main") + +mod = tvm.IRModule({main_gvar: main_func, add_gvar: add_func}) + +###################################################################### +# Render the graph with Relay Visualizer on the terminal +# ------------------------------------------------------ +# The terminal can show a Relay IR module in text similar to clang AST-dump. +# We should see ``main`` and ``AddFunc`` function. ``AddFunc`` is called twice in the ``main`` function. +viz = relay_viz.RelayVisualizer(mod) +viz.render() + +###################################################################### +# Customize Parser for Interested Relay Types +# ------------------------------------------- +# Sometimes we want to emphasize interested information, or parse things differently for a specific usage. +# It is possible to provide customized parsers as long as it obeys the interface. +# Here demostrate how to customize parsers for ``relay.var``. +# We need to implement abstract interface :py:class:`tvm.contrib.relay_viz.interface.VizParser`. +class YourAwesomeParser(VizParser): + def __init__(self): + self._delegate = TermVizParser() + + def get_node_edges( + self, + node: relay.Expr, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, str], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + + if isinstance(node, relay.Var): + node = VizNode(node_to_id[node], "AwesomeVar", f"name_hint {node.name_hint}") + # no edge is introduced. So return an empty list. + return node, [] + + # delegate other types to the other parser. + return self._delegate.get_node_edges(node, relay_param, node_to_id) + + +###################################################################### +# Pass the parser and an interested renderer to visualizer. +# Here we just the terminal renderer. +viz = relay_viz.RelayVisualizer(mod, {}, TermPlotter(), YourAwesomeParser()) +viz.render() + +###################################################################### +# Customization around Graph and Plotter +# ------------------------------------------- +# Besides parsers, we can also customize graph and renderers by implementing +# abstract class :py:class:`tvm.contrib.relay_viz.interface.VizGraph` and +# :py:class:`tvm.contrib.relay_viz.interface.Plotter`. +# Here we override the ``TermGraph`` defined in ``terminal.py`` for easier demo. +# We add a hook duplicating above ``AwesomeVar``, and make ``TermPlotter`` use the new class. +class AwesomeGraph(TermGraph): + def node(self, viz_node): + # add the node first + super().node(viz_node) + # if it's AwesomeVar, duplicate it. + if viz_node.type_name == "AwesomeVar": + duplicated_id = f"duplciated_{viz_node.identity}" + duplicated_type = "double AwesomeVar" + super().node(VizNode(duplicated_id, duplicated_type, "")) + # connect the duplicated var to the original one + super().edge(VizEdge(duplicated_id, viz_node.identity)) + + +# override TermPlotter to use `AwesomeGraph` instead +class AwesomePlotter(TermPlotter): + def create_graph(self, name): + self._name_to_graph[name] = AwesomeGraph(name) + return self._name_to_graph[name] + + +viz = relay_viz.RelayVisualizer(mod, {}, AwesomePlotter(), YourAwesomeParser()) +viz.render() + +###################################################################### +# Summary +# ------- +# This tutorial demonstrates the usage of Relay Visualizer and customization. +# The class :py:class:`tvm.contrib.relay_viz.RelayVisualizer` is composed of interfaces +# defined in ``interface.py``. +# +# It is aimed for quick look-then-fix iterations. +# The constructor arguments are intended to be simple, while the customization is still +# possible through a set of interface classes. +# diff --git a/python/tvm/contrib/relay_viz/__init__.py b/python/tvm/contrib/relay_viz/__init__.py new file mode 100644 index 000000000000..32814b577d0d --- /dev/null +++ b/python/tvm/contrib/relay_viz/__init__.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relay IR Visualizer""" +from typing import Dict +import tvm +from tvm import relay +from .interface import ( + Plotter, + VizGraph, + VizParser, +) +from .terminal import ( + TermPlotter, + TermVizParser, +) + + +class RelayVisualizer: + """Relay IR Visualizer + + Parameters + ---------- + relay_mod: tvm.IRModule + Relay IR module. + relay_param: None | Dict[str, tvm.runtime.NDArray] + Relay parameter dictionary. Default `None`. + plotter: Plotter + An instance of class inheriting from Plotter interface. + Default is an instance of `terminal.TermPlotter`. + parser: VizParser + An instance of class inheriting from VizParser interface. + Default is an instance of `terminal.TermVizParser`. + """ + + def __init__( + self, + relay_mod: tvm.IRModule, + relay_param: Dict[str, tvm.runtime.NDArray] = None, + plotter: Plotter = None, + parser: VizParser = None, + ): + self._plotter = plotter if plotter is not None else TermPlotter() + self._relay_param = relay_param if relay_param is not None else {} + self._parser = parser if parser is not None else TermVizParser() + + global_vars = relay_mod.get_global_vars() + graph_names = [] + # If we have main function, put it to the first. + # Then main function can be shown on the top. + for gv_node in global_vars: + if gv_node.name_hint == "main": + graph_names.insert(0, gv_node.name_hint) + else: + graph_names.append(gv_node.name_hint) + + node_to_id = {} + # callback to generate an unique string-ID for nodes. + def traverse_expr(node): + if node in node_to_id: + return + node_to_id[node] = str(len(node_to_id)) + + for name in graph_names: + node_to_id.clear() + relay.analysis.post_order_visit(relay_mod[name], traverse_expr) + graph = self._plotter.create_graph(name) + self._add_nodes(graph, node_to_id) + + def _add_nodes(self, graph: VizGraph, node_to_id: Dict[relay.Expr, str]): + """add nodes and to the graph. + + Parameters + ---------- + graph : VizGraph + a VizGraph for nodes to be added to. + + node_to_id : Dict[relay.expr, str] + a mapping from nodes to an unique ID. + + relay_param : Dict[str, tvm.runtime.NDarray] + relay parameter dictionary. + """ + for node in node_to_id: + viz_node, viz_edges = self._parser.get_node_edges(node, self._relay_param, node_to_id) + if viz_node is not None: + graph.node(viz_node) + for edge in viz_edges: + graph.edge(edge) + + def render(self, filename: str = None) -> None: + self._plotter.render(filename=filename) diff --git a/python/tvm/contrib/relay_viz/interface.py b/python/tvm/contrib/relay_viz/interface.py new file mode 100644 index 000000000000..6e52f024b1c5 --- /dev/null +++ b/python/tvm/contrib/relay_viz/interface.py @@ -0,0 +1,323 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Abstract class used by :py:class:`tvm.contrib.relay_viz.RelayVisualizer`.""" +import abc +from typing import ( + Dict, + Union, + Tuple, + List, +) + +import tvm +from tvm import relay + +UNKNOWN_TYPE = "unknown" + + +class VizNode: + """VizNode carry node information for `VizGraph` interface. + + Parameters + ---------- + node_id: str + Unique identifier for this node. + node_type: str + Type of this node. + node_detail: str + Any supplement for this node such as attributes. + """ + + def __init__(self, node_id: str, node_type: str, node_detail: str): + self._id = node_id + self._type = node_type + self._detail = node_detail + + @property + def identity(self) -> Union[int, str]: + return self._id + + @property + def type_name(self) -> str: + return self._type + + @property + def detail(self) -> str: + return self._detail + + +class VizEdge: + """VizEdge connect two `VizNode`. + + Parameters + ---------- + start_node: str + The identifier of the node starting the edge. + end_node: str + The identifier of the node ending the edge. + """ + + def __init__(self, start_node: str, end_node: str): + self._start_node = start_node + self._end_node = end_node + + @property + def start(self) -> str: + return self._start_node + + @property + def end(self) -> str: + return self._end_node + + +class VizParser(abc.ABC): + """VizParser parses out a VizNode and VizEdges from a `relay.Expr`.""" + + @abc.abstractmethod + def get_node_edges( + self, + node: relay.Expr, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, str], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + """Get VizNode and VizEdges for a `relay.Expr`. + + Parameters + ---------- + node : relay.Expr + relay.Expr which will be parsed and generate a node and edges. + + relay_param: Dict[str, tvm.runtime.NDArray] + relay parameters dictionary. + + node_to_id : Dict[relay.Expr, str] + This is a mapping from relay.Expr to a unique id, generated by `RelayVisualizer`. + + Returns + ------- + rv1 : Union[VizNode, None] + VizNode represent the relay.Expr. If the relay.Expr is not intended to introduce a node + to the graph, return None. + + rv2 : List[VizEdge] + a list of VizEdges to describe the connectivity of the relay.Expr. + Can be empty list to indicate no connectivity. + """ + + +class VizGraph(abc.ABC): + """Abstract class for graph, which is composed of nodes and edges.""" + + @abc.abstractmethod + def node(self, viz_node: VizNode) -> None: + """Add a node to the underlying graph. + Nodes in a Relay IR Module are expected to be added in the post-order. + + Parameters + ---------- + viz_node : VizNode + A `VizNode` instance. + """ + + @abc.abstractmethod + def edge(self, viz_edge: VizEdge) -> None: + """Add an edge to the underlying graph. + + Parameters + ---------- + id_start : VizEdge + A `VizEdge` instance. + """ + + +class DefaultVizParser(VizParser): + """DefaultVizParser provde a set of logics to parse a various relay types. + These logics are inspired and heavily based on + `visualize` function in https://tvm.apache.org/2020/07/14/bert-pytorch-tvm + """ + + def get_node_edges( + self, + node: relay.Expr, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, str], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + if isinstance(node, relay.Function): + return self._function(node, node_to_id) + if isinstance(node, relay.expr.Call): + return self._call(node, node_to_id) + if isinstance(node, relay.expr.Var): + return self._var(node, relay_param, node_to_id) + if isinstance(node, relay.expr.Tuple): + return self._tuple(node, node_to_id) + if isinstance(node, relay.expr.TupleGetItem): + return self._tuple_get_item(node, node_to_id) + if isinstance(node, relay.expr.Constant): + return self._constant(node, node_to_id) + # GlobalVar possibly mean another global relay function, + # which is expected to in "Graph" level, not in "Node" level. + if isinstance(node, (relay.expr.GlobalVar, tvm.ir.Op)): + return None, [] + + viz_node = VizNode(node_to_id[node], UNKNOWN_TYPE, f"don't know how to parse {type(node)}") + viz_edges = [] + return viz_node, viz_edges + + def _var( + self, + node: relay.Expr, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, str], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + """Render rule for a relay var node""" + + node_id = node_to_id[node] + name_hint = node.name_hint + node_detail = f"name_hint: {name_hint}" + node_type = "Var(Param)" if name_hint in relay_param else "Var(Input)" + + if node.type_annotation is not None: + if hasattr(node.type_annotation, "shape"): + shape = tuple(map(int, node.type_annotation.shape)) + dtype = node.type_annotation.dtype + node_detail = f"{node_detail}\nshape: {shape}\ndtype: {dtype}" + else: + node_detail = f"{node_detail}\ntype_annotation: {node.type_annotation}" + + # only node + viz_node = VizNode(node_id, node_type, node_detail) + viz_edges = [] + return viz_node, viz_edges + + def _function( + self, + node: relay.Expr, + node_to_id: Dict[relay.Expr, str], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + """Render rule for a relay function node""" + node_details = [] + name = "" + func_attrs = node.attrs + if func_attrs: + node_details = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()] + # "Composite" might from relay.transform.MergeComposite + if "Composite" in func_attrs.keys(): + name = func_attrs["Composite"] + node_id = node_to_id[node] + + # Body -> FunctionNode + viz_node = VizNode(node_id, f"Func {name}", "\n".join(node_details)) + viz_edges = [VizEdge(node_to_id[node.body], node_id)] + return viz_node, viz_edges + + def _call( + self, + node: relay.Expr, + node_to_id: Dict[relay.Expr, str], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + """Render rule for a relay call node""" + node_id = node_to_id[node] + op_name = UNKNOWN_TYPE + node_detail = [] + if isinstance(node.op, tvm.ir.Op): + op_name = node.op.name + if node.attrs: + node_detail = [f"{k}: {node.attrs.get_str(k)}" for k in node.attrs.keys()] + elif isinstance(node.op, relay.Function): + func_attrs = node.op.attrs + op_name = "Anonymous Func" + if func_attrs: + node_detail = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()] + # "Composite" might from relay.transform.MergeComposite + if "Composite" in func_attrs.keys(): + op_name = func_attrs["Composite"] + elif isinstance(node.op, relay.GlobalVar): + op_name = "GlobalVar" + node_detail = [f"GlobalVar.name_hint: {node.op.name_hint}"] + else: + op_name = str(type(node.op)).split(".")[-1].split("'")[0] + + # Arguments -> CallNode + viz_node = VizNode(node_id, f"Call {op_name}", "\n".join(node_detail)) + args = [node_to_id[arg] for arg in node.args] + viz_edges = [VizEdge(arg, node_id) for arg in args] + return viz_node, viz_edges + + def _tuple( + self, + node: relay.Expr, + node_to_id: Dict[relay.Expr, str], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + node_id = node_to_id[node] + + # Fields -> TupleNode + viz_node = VizNode(node_id, "Tuple", "") + viz_edges = [VizEdge(node_to_id[field], node_id) for field in node.fields] + return viz_node, viz_edges + + def _tuple_get_item( + self, + node: relay.Expr, + node_to_id: Dict[relay.Expr, str], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + node_id = node_to_id[node] + + # Tuple -> TupleGetItemNode + viz_node = VizNode(node_id, f"TupleGetItem", "idx: {node.index}") + viz_edges = [VizEdge(node_to_id[node.tuple_value], node_id)] + return viz_node, viz_edges + + def _constant( + self, + node: relay.Expr, + node_to_id: Dict[relay.Expr, str], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + node_id = node_to_id[node] + node_detail = f"shape: {node.data.shape}, dtype: {node.data.dtype}" + + # only node + viz_node = VizNode(node_id, "Const", node_detail) + viz_edges = [] + return viz_node, viz_edges + + +class Plotter(abc.ABC): + """Plotter can render a collection of Graph interfaces to a file.""" + + @abc.abstractmethod + def create_graph(self, name: str) -> VizGraph: + """Create a VizGraph + + Parameters + ---------- + name : str + the name of the graph + + Return + ------ + rv1: an instance of class inheriting from VizGraph interface. + """ + + @abc.abstractmethod + def render(self, filename: str) -> None: + """Render the graph as a file. + + Parameters + ---------- + filename : str + see the definition of implemented class. + """ diff --git a/python/tvm/contrib/relay_viz/terminal.py b/python/tvm/contrib/relay_viz/terminal.py new file mode 100644 index 000000000000..7b72d9da4333 --- /dev/null +++ b/python/tvm/contrib/relay_viz/terminal.py @@ -0,0 +1,238 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Visualize Relay IR in AST text-form.""" + +from collections import deque +from typing import ( + Dict, + Union, + Tuple, + List, +) +import tvm +from tvm import relay +from .interface import ( + DefaultVizParser, + Plotter, + VizEdge, + VizGraph, + VizNode, +) + + +class TermVizParser(DefaultVizParser): + """`TermVizParser` parse nodes and edges for `TermPlotter`.""" + + def __init__(self): + self._default_parser = DefaultVizParser() + + def get_node_edges( + self, + node: relay.Expr, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, str], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + """Parse a node and edges from a relay.Expr.""" + if isinstance(node, relay.Call): + return self._call(node, node_to_id) + if isinstance(node, relay.Let): + return self._let(node, node_to_id) + if isinstance(node, relay.GlobalVar): + return self._global_var(node, node_to_id) + if isinstance(node, relay.If): + return self._if(node, node_to_id) + if isinstance(node, tvm.ir.Op): + return self._op(node, node_to_id) + if isinstance(node, relay.Function): + return self._function(node, node_to_id) + + # Leverage logics from default parser. + return self._default_parser.get_node_edges(node, relay_param, node_to_id) + + def _call(self, node, node_to_id): + node_id = node_to_id[node] + viz_node = VizNode(node_id, "Call", "") + viz_edges = [VizEdge(node_to_id[node.op], node_id)] + for arg in node.args: + arg_id = node_to_id[arg] + viz_edges.append(VizEdge(arg_id, node_id)) + return viz_node, viz_edges + + def _let(self, node, node_to_id): + node_id = node_to_id[node] + viz_node = VizNode(node_id, "Let", "(var, val, body)") + viz_edges = [ + VizEdge(node_to_id[node.var], node_id), + VizEdge(node_to_id[node.value], node_id), + VizEdge(node_to_id[node.body], node_id), + ] + return viz_node, viz_edges + + def _global_var(self, node, node_to_id): + node_id = node_to_id[node] + viz_node = VizNode(node_id, "GlobalVar", node.name_hint) + viz_edges = [] + return viz_node, viz_edges + + def _if(self, node, node_to_id): + node_id = node_to_id[node] + viz_node = VizNode(node_id, "If", "(cond, true, false)") + viz_edges = [ + VizEdge(node_to_id[node.cond], node_id), + VizEdge(node_to_id[node.true_branch], node_id), + VizEdge(node_to_id[node.false_branch], node_id), + ] + return viz_node, viz_edges + + def _op(self, node, node_to_id): + node_id = node_to_id[node] + op_name = node.name + viz_node = VizNode(node_id, op_name, "") + viz_edges = [] + return viz_node, viz_edges + + def _function(self, node, node_to_id): + node_id = node_to_id[node] + viz_node = VizNode(node_id, "Func", str(node.params)) + viz_edges = [VizEdge(node_to_id[node.body], node_id)] + return viz_node, viz_edges + + +class TermNode: + """TermNode is aimed to generate text more suitable for terminal visualization.""" + + def __init__(self, viz_node: VizNode): + self.type = viz_node.type_name + # We don't want too many lines in a terminal. + self.other_info = viz_node.detail.replace("\n", ", ") + + +class TermGraph(VizGraph): + """Terminal graph for a relay IR Module + + Parameters + ---------- + name: str + name of this graph. + """ + + def __init__(self, name: str): + self._name = name + # A graph in adjacency list form. + # The key is source node, and the value is a list of destination nodes. + self._graph = {} + # a hash table for quick searching. + self._id_to_term_node = {} + # node_id in reversed post order + # That mean, root is the first node. + self._node_id_rpo = deque() + + def node(self, viz_node: VizNode) -> None: + """Add a node to the underlying graph. + Nodes in a Relay IR Module are expected to be added in the post-order. + + Parameters + ---------- + viz_node : VizNode + A `VizNode` instance. + """ + + self._node_id_rpo.appendleft(viz_node.identity) + + if viz_node.identity not in self._graph: + # Add the node into the graph. + self._graph[viz_node.identity] = [] + + # Create TermNode from VizNode + node = TermNode(viz_node) + self._id_to_term_node[viz_node.identity] = node + + def edge(self, viz_edge: VizEdge) -> None: + """Add an edge to the terminal graph. + + Parameters + ---------- + id_start : VizEdge + A `VizEdge` instance. + """ + # Take CallNode as an example, instead of "arguments point to CallNode", + # we want "CallNode points to arguments" in ast-dump form. + # + # The direction of edge is typically controlled by the implemented VizParser. + # Reverse start/end here simply because we leverage default parser implementation. + if viz_edge.end in self._graph: + self._graph[viz_edge.end].append(viz_edge.start) + else: + self._graph[viz_edge.end] = [viz_edge.start] + + def render(self) -> str: + """Draw a terminal graph + + Returns + ------- + rv1: str + text representing a graph. + """ + lines = [] + seen_node = set() + + def gen_line(indent, n_id): + if (indent, n_id) in seen_node: + return + seen_node.add((indent, n_id)) + + conn_symbol = ["|--", "`--"] + last = len(self._graph[n_id]) - 1 + for i, next_n_id in enumerate(self._graph[n_id]): + node = self._id_to_term_node[next_n_id] + lines.append( + f"{indent}{conn_symbol[1 if i==last else 0]}{node.type} {node.other_info}" + ) + next_indent = indent + # increase indent for the next level. + next_indent += " " if (i == last) else "| " + gen_line(next_indent, next_n_id) + + first_node_id = self._node_id_rpo[0] + first_node = self._id_to_term_node[first_node_id] + lines.append(f"@{self._name}({first_node.other_info})") + gen_line("", first_node_id) + + return "\n".join(lines) + + +class TermPlotter(Plotter): + """Terminal plotter""" + + def __init__(self): + self._name_to_graph = {} + + def create_graph(self, name): + self._name_to_graph[name] = TermGraph(name) + return self._name_to_graph[name] + + def render(self, filename): + """If filename is None, print to stdio. Otherwise, write to the filename.""" + lines = [] + for name in self._name_to_graph: + text_graph = self._name_to_graph[name].render() + lines.append(text_graph) + if filename is None: + print("\n".join(lines)) + else: + with open(filename, "w") as out_file: + out_file.write("\n".join(lines))