diff --git a/docs/reference/api/python/contrib.rst b/docs/reference/api/python/contrib.rst index 0eb3024c2d08e..6cef118967536 100644 --- a/docs/reference/api/python/contrib.rst +++ b/docs/reference/api/python/contrib.rst @@ -92,6 +92,16 @@ tvm.contrib.random .. automodule:: tvm.contrib.random :members: +tvm.contrib.relay_viz +~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: tvm.contrib.relay_viz + :members: RelayVisualizer +.. autoattribute:: tvm.contrib.relay_viz.PlotterBackend.BOKEH +.. autoattribute:: tvm.contrib.relay_viz.PlotterBackend.TERMINAL +.. automodule:: tvm.contrib.relay_viz.plotter + :members: +.. automodule:: tvm.contrib.relay_viz.node_edge_gen + :members: tvm.contrib.rocblas ~~~~~~~~~~~~~~~~~~~ diff --git a/gallery/how_to/work_with_relay/using_relay_viz.py b/gallery/how_to/work_with_relay/using_relay_viz.py new file mode 100644 index 0000000000000..827be8ba3c0fa --- /dev/null +++ b/gallery/how_to/work_with_relay/using_relay_viz.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=line-too-long +""" +Use Relay Visualizer to Visualize Relay +============================================================ +**Author**: `Chi-Wei Wang `_ + +This is an introduction about using Relay Visualizer to visualize a Relay IR module. + +Relay IR module can contain lots of operations. Although individual +operations are usually easy to understand, they become complicated quickly +when you put them together. It could get even worse while optimiztion passes +come into play. + +This utility abstracts an IR module as graphs containing nodes and edges. +It provides a default parser to interpret an IR modules with nodes and edges. +Two renderer backends are also implemented to visualize them. + +Here we use a backend showing Relay IR module in the terminal for illustation. +It is a much more lightweight compared to another backend using `Bokeh `_. +See ``/python/tvm/contrib/relay_viz/README.md``. +Also we will introduce how to implement customized parsers and renderers through +some interfaces classes. +""" +from typing import ( + Dict, + Union, + Tuple, + List, +) +import tvm +from tvm import relay +from tvm.contrib import relay_viz +from tvm.contrib.relay_viz.node_edge_gen import ( + VizNode, + VizEdge, + NodeEdgeGenerator, +) +from tvm.contrib.relay_viz.terminal import ( + TermNodeEdgeGenerator, + TermGraph, + TermPlotter, +) + +###################################################################### +# Define a Relay IR Module with multiple GlobalVar +# ------------------------------------------------ +# Let's build an example Relay IR Module containing multiple ``GlobalVar``. +# We define an ``add`` function and call it in the main function. +data = relay.var("data") +bias = relay.var("bias") +add_op = relay.add(data, bias) +add_func = relay.Function([data, bias], add_op) +add_gvar = relay.GlobalVar("AddFunc") + +input0 = relay.var("input0") +input1 = relay.var("input1") +input2 = relay.var("input2") +add_01 = relay.Call(add_gvar, [input0, input1]) +add_012 = relay.Call(add_gvar, [input2, add_01]) +main_func = relay.Function([input0, input1, input2], add_012) +main_gvar = relay.GlobalVar("main") + +mod = tvm.IRModule({main_gvar: main_func, add_gvar: add_func}) + +###################################################################### +# Render the graph with Relay Visualizer on the terminal +# ------------------------------------------------------ +# The terminal backend can show a Relay IR module as in a text-form +# similar to `clang ast-dump `_. +# We should see ``main`` and ``AddFunc`` function. ``AddFunc`` is called twice in the ``main`` function. +viz = relay_viz.RelayVisualizer(mod, {}, relay_viz.PlotterBackend.TERMINAL) +viz.render() + +###################################################################### +# Customize Parser for Interested Relay Types +# ------------------------------------------- +# Sometimes the information shown by the default implementation is not suitable +# for a specific usage. It is possible to provide your own parser and renderer. +# Here demostrate how to customize parsers for ``relay.var``. +# We need to implement :py:class:`tvm.contrib.relay_viz.node_edge_gen.NodeEdgeGenerator` interface. +class YourAwesomeParser(NodeEdgeGenerator): + def __init__(self): + self._org_parser = TermNodeEdgeGenerator() + + def get_node_edges( + self, + node: relay.Expr, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, Union[int, str]], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + + if isinstance(node, relay.Var): + node = VizNode(node_to_id[node], "AwesomeVar", f"name_hint {node.name_hint}") + # no edge is introduced. So return an empty list. + ret = (node, []) + return ret + + # delegate other types to the original parser. + return self._org_parser.get_node_edges(node, relay_param, node_to_id) + + +###################################################################### +# Pass a tuple of :py:class:`tvm.contrib.relay_viz.plotter.Plotter` and +# :py:class:`tvm.contrib.relay_viz.node_edge_gen.NodeEdgeGenerator` instances +# to ``RelayVisualizer``. Here we re-use the Plotter interface implemented inside +# ``relay_viz.terminal`` module. +viz = relay_viz.RelayVisualizer(mod, {}, (TermPlotter(), YourAwesomeParser())) +viz.render() + +###################################################################### +# More Customization around Graph and Plotter +# ------------------------------------------- +# All ``RelayVisualizer`` care about are interfaces defined in ``plotter.py`` and +# ``node_edge_generator.py``. We can override them to introduce custimized logics. +# For example, if we want the Graph to duplicate above ``AwesomeVar`` while it is added, +# we can override ``relay_viz.terminal.TermGraph.node``. +class AwesomeGraph(TermGraph): + def node(self, node_id, node_type, node_detail): + # add original node first + super().node(node_id, node_type, node_detail) + if node_type == "AwesomeVar": + duplicated_id = f"duplciated_{node_id}" + duplicated_type = "double AwesomeVar" + super().node(duplicated_id, duplicated_type, "") + # connect the duplicated var to the original one + super().edge(duplicated_id, node_id) + + +# override TermPlotter to return `AwesomeGraph` instead +class AwesomePlotter(TermPlotter): + def create_graph(self, name): + self._name_to_graph[name] = AwesomeGraph(name) + return self._name_to_graph[name] + + +viz = relay_viz.RelayVisualizer(mod, {}, (AwesomePlotter(), YourAwesomeParser())) +viz.render() + +###################################################################### +# Summary +# ------- +# This tutorial demonstrates the usage of Relay Visualizer. +# The class :py:class:`tvm.contrib.relay_viz.RelayVisualizer` is composed of interfaces +# defined in ``plotter.py`` and ``node_edge_generator.py``. It provides a single entry point +# while keeping the possibility of implementing customized visualizer in various cases. +# diff --git a/python/tvm/contrib/relay_viz/README.md b/python/tvm/contrib/relay_viz/README.md index bb6e964e8f074..c1d7d2245249b 100644 --- a/python/tvm/contrib/relay_viz/README.md +++ b/python/tvm/contrib/relay_viz/README.md @@ -28,6 +28,10 @@ This tool target to visualize Relay IR. ## Requirement +### Terminal Backend +1. TVM + +### Bokeh Backend 1. TVM 2. graphviz 2. pydot @@ -66,9 +70,9 @@ This utility is composed of two parts: `node_edge_gen.py` and `plotter.py`. `plotter.py` define interfaces of `Graph` and `Plotter`. `Plotter` is responsible to render a collection of `Graph`. -`node_edge_gen.py` define interfaces of converting Relay IR modules to nodes/edges consumed by `Graph`. Further, this python module also provide a default implementation for common relay types. +`node_edge_gen.py` define interfaces of converting Relay IR modules to nodes and edges. Further, this python module provide a default implementation for common relay types. If customization is wanted for a certain relay type, we can implement the `NodeEdgeGenerator` interface, handling that relay type accordingly, and delegate other types to the default implementation. See `_terminal.py` for an example usage. -These two interfaces are glued by the top level class `RelayVisualizer`, which passes a relay module to `NodeEdgeGenerator` and add nodes/edges to `Graph`. -Then, it render the plot by `Plotter.render()`. +These two interfaces are glued by the top level class `RelayVisualizer`, which passes a relay module to `NodeEdgeGenerator` and add nodes and edges to `Graph`. +Then, it render the plot by calling `Plotter.render()`. diff --git a/python/tvm/contrib/relay_viz/__init__.py b/python/tvm/contrib/relay_viz/__init__.py index 855adeb2e8ee4..1015f4781f6fd 100644 --- a/python/tvm/contrib/relay_viz/__init__.py +++ b/python/tvm/contrib/relay_viz/__init__.py @@ -28,14 +28,26 @@ class PlotterBackend(Enum): - """Enumeration for available plotters.""" + """Enumeration for available plotter backends.""" BOKEH = "bokeh" TERMINAL = "terminal" class RelayVisualizer: - """Relay IR Visualizer""" + """Relay IR Visualizer + + Parameters + ---------- + relay_mod : tvm.IRModule + Relay IR module. + relay_param: None | Dict[str, tvm.runtime.NDArray] + Relay parameter dictionary. Default `None`. + backend: PlotterBackend | Tuple[Plotter, NodeEdgeGenerator] + The backend used to render graphs. It can be a tuple of an implemented Plotter instance and + NodeEdgeGenerator instance to introduce customized parsing and visualization logics. + Default ``PlotterBackend.TERMINAL``. + """ def __init__( self, @@ -43,14 +55,6 @@ def __init__( relay_param: Union[None, Dict[str, tvm.runtime.NDArray]] = None, backend: Union[PlotterBackend, Tuple[Plotter, NodeEdgeGenerator]] = PlotterBackend.TERMINAL, ): - """Visualize Relay IR. - - Parameters - ---------- - relay_mod : tvm.IRModule, Relay IR module - relay_param: None | Dict[str, tvm.runtime.NDArray], Relay parameter dictionary. Default `None`. - backend: PlotterBackend | Tuple[Plotter, NodeEdgeGenerator], Default `PlotterBackend.TERMINAL`. - """ self._plotter, self._ne_generator = get_plotter_and_generator(backend) self._relay_param = relay_param if relay_param is not None else {} @@ -83,8 +87,10 @@ def _add_nodes(self, graph, node_to_id, relay_param): Parameters ---------- - graph : `plotter.Graph` + graph : plotter.Graph + node_to_id : Dict[relay.expr, str | int] + relay_param : Dict[str, tvm.runtime.NDarray] """ for node in node_to_id: @@ -102,11 +108,11 @@ def get_plotter_and_generator(backend): """Specify the Plottor and its NodeEdgeGenerator""" if isinstance(backend, (tuple, list)) and len(backend) == 2: if not isinstance(backend[0], Plotter): - raise ValueError(f"First element of backend should be derived from {type(Plotter)}") + raise ValueError(f"First element should be an instance derived from {type(Plotter)}") if not isinstance(backend[1], NodeEdgeGenerator): raise ValueError( - f"Second element of backend should be derived from {type(NodeEdgeGenerator)}" + f"Second element should be an instance derived from {type(NodeEdgeGenerator)}" ) return backend @@ -118,7 +124,7 @@ def get_plotter_and_generator(backend): # Basically we want to keep them optional. Users can choose plotters they want to install. if backend == PlotterBackend.BOKEH: # pylint: disable=import-outside-toplevel - from ._bokeh import ( + from .bokeh import ( BokehPlotter, BokehNodeEdgeGenerator, ) @@ -127,7 +133,7 @@ def get_plotter_and_generator(backend): ne_generator = BokehNodeEdgeGenerator() elif backend == PlotterBackend.TERMINAL: # pylint: disable=import-outside-toplevel - from ._terminal import ( + from .terminal import ( TermPlotter, TermNodeEdgeGenerator, ) diff --git a/python/tvm/contrib/relay_viz/_bokeh.py b/python/tvm/contrib/relay_viz/bokeh.py similarity index 98% rename from python/tvm/contrib/relay_viz/_bokeh.py rename to python/tvm/contrib/relay_viz/bokeh.py index 58716510c91a3..6ea82188463e0 100644 --- a/python/tvm/contrib/relay_viz/_bokeh.py +++ b/python/tvm/contrib/relay_viz/bokeh.py @@ -19,22 +19,9 @@ import functools import logging -_LOGGER = logging.getLogger(__name__) - import numpy as np - -try: - import pydot -except ImportError: - _LOGGER.critical("pydot library is required. You might want to run pip install pydot.") - raise - -try: - from bokeh.io import output_file, save -except ImportError: - _LOGGER.critical("bokeh library is required. You might want to run pip install bokeh.") - raise - +import pydot +from bokeh.io import output_file, save from bokeh.models import ( ColumnDataSource, CustomJS, @@ -63,6 +50,8 @@ from .node_edge_gen import DefaultNodeEdgeGenerator +_LOGGER = logging.getLogger(__name__) + # Use default node/edge generator BokehNodeEdgeGenerator = DefaultNodeEdgeGenerator diff --git a/python/tvm/contrib/relay_viz/node_edge_gen.py b/python/tvm/contrib/relay_viz/node_edge_gen.py index ffe54e7418979..13021a81a2867 100644 --- a/python/tvm/contrib/relay_viz/node_edge_gen.py +++ b/python/tvm/contrib/relay_viz/node_edge_gen.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""NodeEdgeGenerator interface""" +"""NodeEdgeGenerator interface for :py:class:`tvm.contrib.relay_viz.plotter.Graph`.""" import abc from typing import ( Dict, @@ -28,7 +28,7 @@ UNKNOWN_TYPE = "unknown" -class Node: +class VizNode: """Node carry information used by `plotter.Graph` interface.""" def __init__(self, node_id: Union[int, str], node_type: str, node_detail: str): @@ -49,8 +49,8 @@ def detail(self) -> str: return self._detail -class Edge: - """Edge for `plotter.Graph` interface.""" +class VizEdge: + """Edges for `plotter.Graph` interface.""" def __init__(self, start_node: Union[int, str], end_node: Union[int, str]): self._start_node = start_node @@ -74,10 +74,29 @@ def get_node_edges( node: relay.Expr, relay_param: Dict[str, tvm.runtime.NDArray], node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: - """Generate node and edges consumed by Graph interfaces - The returned tuple containing Node and a list of Edge instances. - Tuple[None, list[]] for null results. + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + """Generate node and edges consumed by Graph interfaces. + + Parameters + ---------- + node : relay.Expr + relay.Expr which will be parsed and generate a node and edges. + + relay_param: Dict[str, tvm.runtime.NDArray] + relay parameters dictionary. + + node_to_id : Dict[relay.Expr, Union[int, str]] + a mapping from relay.Expr to node id which should be unique. + + Returns + ------- + rv1 : Union[VizNode, None] + VizNode represent the relay.Expr. If the relay.Expr is not intended to introduce a node + to the graph, return None. + + rv2 : List[VizEdge] + a list of VizEdge to describe the connectivity of the relay.Expr. + Can be empty list to indicate no connectivity. """ @@ -88,59 +107,72 @@ class DefaultNodeEdgeGenerator(NodeEdgeGenerator): """ def __init__(self): - self.render_rules = {} - self.build_rules() + self._render_rules = {} + self._build_rules() - def var_node( + def get_node_edges( self, node: relay.Expr, relay_param: Dict[str, tvm.runtime.NDArray], node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + try: + node_info, edge_info = self._render_rules[type(node)](node, relay_param, node_to_id) + except KeyError: + node_info = VizNode( + node_to_id[node], UNKNOWN_TYPE, f"don't know how to parse {type(node)}" + ) + edge_info = [] + return node_info, edge_info + + def _var_node( + self, + node: relay.Expr, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, Union[int, str]], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: """Render rule for a relay var node""" node_id = node_to_id[node] name_hint = node.name_hint - node_detail = "" + node_detail = f"name_hint: {name_hint}" node_type = "Var(Param)" if name_hint in relay_param else "Var(Input)" if node.type_annotation is not None: if hasattr(node.type_annotation, "shape"): shape = tuple(map(int, node.type_annotation.shape)) dtype = node.type_annotation.dtype - node_detail = "name_hint: {}\nshape: {}\ndtype: {}".format(name_hint, shape, dtype) + node_detail = f"name_hint: {name_hint}\nshape: {shape}\ndtype: {dtype}" else: - node_detail = "name_hint: {}\ntype_annotation: {}".format( - name_hint, node.type_annotation - ) - node_info = Node(node_id, node_type, node_detail) + node_detail = f"name_hint: {name_hint}\ntype_annotation: {node.type_annotation}" + node_info = VizNode(node_id, node_type, node_detail) edge_info = [] return node_info, edge_info - def function_node( + def _function_node( self, node: relay.Expr, _: Dict[str, tvm.runtime.NDArray], # relay_param node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: """Render rule for a relay function node""" node_details = [] name = "" func_attrs = node.attrs if func_attrs: - node_details = ["{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys()] + node_details = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()] # "Composite" might from relay.transform.MergeComposite if "Composite" in func_attrs.keys(): name = func_attrs["Composite"] node_id = node_to_id[node] - node_info = Node(node_id, f"Func {name}", "\n".join(node_details)) - edge_info = [Edge(node_to_id[node.body], node_id)] + node_info = VizNode(node_id, f"Func {name}", "\n".join(node_details)) + edge_info = [VizEdge(node_to_id[node.body], node_id)] return node_info, edge_info - def call_node( + def _call_node( self, node: relay.Expr, _: Dict[str, tvm.runtime.NDArray], # relay_param node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: """Render rule for a relay call node""" node_id = node_to_id[node] op_name = UNKNOWN_TYPE @@ -148,12 +180,12 @@ def call_node( if isinstance(node.op, tvm.ir.Op): op_name = node.op.name if node.attrs: - node_detail = ["{}: {}".format(k, node.attrs.get_str(k)) for k in node.attrs.keys()] + node_detail = [f"{k}: {node.attrs.get_str(k)}" for k in node.attrs.keys()] elif isinstance(node.op, relay.Function): func_attrs = node.op.attrs op_name = "Anonymous Func" if func_attrs: - node_detail = ["{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys()] + node_detail = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()] # "Composite" might from relay.transform.MergeComposite if "Composite" in func_attrs.keys(): op_name = func_attrs["Composite"] @@ -163,71 +195,56 @@ def call_node( else: op_name = str(type(node.op)).split(".")[-1].split("'")[0] - node_info = Node(node_id, f"Call {op_name}", "\n".join(node_detail)) + node_info = VizNode(node_id, f"Call {op_name}", "\n".join(node_detail)) args = [node_to_id[arg] for arg in node.args] - edge_info = [Edge(arg, node_id) for arg in args] + edge_info = [VizEdge(arg, node_id) for arg in args] return node_info, edge_info - def tuple_node( + def _tuple_node( self, node: relay.Expr, _: Dict[str, tvm.runtime.NDArray], # relay_param node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: node_id = node_to_id[node] - node_info = Node(node_id, "Tuple", "") - edge_info = [Edge(node_to_id[field], node_id) for field in node.fields] + node_info = VizNode(node_id, "Tuple", "") + edge_info = [VizEdge(node_to_id[field], node_id) for field in node.fields] return node_info, edge_info - def tuple_get_item_node( + def _tuple_get_item_node( self, node: relay.Expr, _: Dict[str, tvm.runtime.NDArray], # relay_param node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: node_id = node_to_id[node] - node_info = Node(node_id, "TupleGetItem", "idx: {}".format(node.index)) - edge_info = [Edge(node_to_id[node.tuple_value], node_id)] + node_info = VizNode(node_id, f"TupleGetItem", "idx: {node.index}") + edge_info = [VizEdge(node_to_id[node.tuple_value], node_id)] return node_info, edge_info - def constant_node( + def _constant_node( self, node: relay.Expr, _: Dict[str, tvm.runtime.NDArray], # relay_param node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: node_id = node_to_id[node] - node_detail = "shape: {}, dtype: {}".format(node.data.shape, node.data.dtype) - node_info = Node(node_id, "Const", node_detail) + node_detail = "shape: {node.data.shape}, dtype: {node.data.dtype}" + node_info = VizNode(node_id, "Const", node_detail) edge_info = [] return node_info, edge_info - def null(self, *_) -> Tuple[None, List[Edge]]: + def _null(self, *_) -> Tuple[None, List[VizEdge]]: return None, [] - def build_rules(self): - self.render_rules = { - tvm.relay.Function: self.function_node, - tvm.relay.expr.Call: self.call_node, - tvm.relay.expr.Var: self.var_node, - tvm.relay.expr.Tuple: self.tuple_node, - tvm.relay.expr.TupleGetItem: self.tuple_get_item_node, - tvm.relay.expr.Constant: self.constant_node, - tvm.relay.expr.GlobalVar: self.null, - tvm.ir.Op: self.null, + def _build_rules(self): + self._render_rules = { + tvm.relay.Function: self._function_node, + tvm.relay.expr.Call: self._call_node, + tvm.relay.expr.Var: self._var_node, + tvm.relay.expr.Tuple: self._tuple_node, + tvm.relay.expr.TupleGetItem: self._tuple_get_item_node, + tvm.relay.expr.Constant: self._constant_node, + tvm.relay.expr.GlobalVar: self._null, + tvm.ir.Op: self._null, } - - def get_node_edges( - self, - node: relay.Expr, - relay_param: Dict[str, tvm.runtime.NDArray], - node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: - try: - node_info, edge_info = self.render_rules[type(node)](node, relay_param, node_to_id) - except KeyError: - node_info = Node( - node_to_id[node], UNKNOWN_TYPE, f"don't know how to parse {type(node)}" - ) - edge_info = [] - return node_info, edge_info diff --git a/python/tvm/contrib/relay_viz/plotter.py b/python/tvm/contrib/relay_viz/plotter.py index 58ed2c02ebd99..de8c24c39a400 100644 --- a/python/tvm/contrib/relay_viz/plotter.py +++ b/python/tvm/contrib/relay_viz/plotter.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Abstract class for plotters.""" +"""Abstract class used by :py:class:`tvm.contrib.relay_viz.RelayVisualizer`.""" import abc from typing import Union @@ -28,9 +28,14 @@ def node(self, node_id: Union[int, str], node_type: str, node_detail: str) -> No Parameters ---------- - node_id : Union[int, str], Serve as the ID to the node. - node_type : str, the type of the node. - node_detail : str, the description of the node. + node_id : Union[int, str] + Serve as the ID to the node. + + node_type : str + the type of the node. + + node_detail : str + the description of the node. """ @abc.abstractmethod @@ -39,8 +44,11 @@ def edge(self, id_start: Union[int, str], id_end: Union[int, str]) -> None: Parameters ---------- - id_start : Union[int, str], the ID to the starting node. - id_end : Union[int, str], the ID to the ending node. + id_start : Union[int, str] + the ID to the starting node. + + id_end : Union[int, str] + the ID to the ending node. """ @@ -53,11 +61,12 @@ def create_graph(self, name: str) -> Graph: Parameters ---------- - name : string, the name of the graph + name : str + the name of the graph Return ------ - Graph instance. + rv1: class Graph """ @abc.abstractmethod @@ -66,5 +75,6 @@ def render(self, filename: str) -> None: Parameters ---------- - filename : string, see the definition of implemented class. + filename : str + see the definition of implemented class. """ diff --git a/python/tvm/contrib/relay_viz/_terminal.py b/python/tvm/contrib/relay_viz/terminal.py similarity index 85% rename from python/tvm/contrib/relay_viz/_terminal.py rename to python/tvm/contrib/relay_viz/terminal.py index 230bd88f5a521..82332d9649070 100644 --- a/python/tvm/contrib/relay_viz/_terminal.py +++ b/python/tvm/contrib/relay_viz/terminal.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Visualize Relay IR in AST text-form""" +"""Visualize Relay IR in AST text-form.""" from collections import deque from typing import ( @@ -33,8 +33,8 @@ ) from .node_edge_gen import ( - Node, - Edge, + VizNode, + VizEdge, NodeEdgeGenerator, DefaultNodeEdgeGenerator, ) @@ -51,7 +51,7 @@ def get_node_edges( node: relay.Expr, relay_param: Dict[str, tvm.runtime.NDArray], node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: """Generate node and edges consumed by TermGraph interfaces""" if isinstance(node, relay.Call): return self._call_node(node, node_to_id) @@ -76,50 +76,50 @@ def get_node_edges( def _call_node(self, node, node_to_id): node_id = node_to_id[node] - node_info = Node(node_id, "Call", "") - edge_info = [Edge(node_to_id[node.op], node_id)] + node_info = VizNode(node_id, "Call", "") + edge_info = [VizEdge(node_to_id[node.op], node_id)] for arg in node.args: arg_nid = node_to_id[arg] - edge_info.append(Edge(arg_nid, node_id)) + edge_info.append(VizEdge(arg_nid, node_id)) return node_info, edge_info def _let_node(self, node, node_to_id): node_id = node_to_id[node] - node_info = Node(node_id, "Let", "(var, val, body)") + node_info = VizNode(node_id, "Let", "(var, val, body)") edge_info = [ - Edge(node_to_id[node.var], node_id), - Edge(node_to_id[node.value], node_id), - Edge(node_to_id[node.body], node_id), + VizEdge(node_to_id[node.var], node_id), + VizEdge(node_to_id[node.value], node_id), + VizEdge(node_to_id[node.body], node_id), ] return node_info, edge_info def _global_var_node(self, node, node_to_id): node_id = node_to_id[node] - node_info = Node(node_id, "GlobalVar", node.name_hint) + node_info = VizNode(node_id, "GlobalVar", node.name_hint) edge_info = [] return node_info, edge_info def _if_node(self, node, node_to_id): node_id = node_to_id[node] - node_info = Node(node_id, "If", "(cond, true, false)") + node_info = VizNode(node_id, "If", "(cond, true, false)") edge_info = [ - Edge(node_to_id[node.cond], node_id), - Edge(node_to_id[node.true_branch], node_id), - Edge(node_to_id[node.false_branch], node_id), + VizEdge(node_to_id[node.cond], node_id), + VizEdge(node_to_id[node.true_branch], node_id), + VizEdge(node_to_id[node.false_branch], node_id), ] return node_info, edge_info def _op_node(self, node, node_to_id): node_id = node_to_id[node] op_name = node.name - node_info = Node(node_id, op_name, "") + node_info = VizNode(node_id, op_name, "") edge_info = [] return node_info, edge_info def _function_node(self, node, node_to_id): node_id = node_to_id[node] - node_info = Node(node_id, "Func", str(node.params)) - edge_info = [Edge(node_to_id[node.body], node_id)] + node_info = VizNode(node_id, "Func", str(node.params)) + edge_info = [VizEdge(node_to_id[node.body], node_id)] return node_info, edge_info