diff --git a/docs/reference/api/python/contrib.rst b/docs/reference/api/python/contrib.rst index 0eb3024c2d08..6cef11896753 100644 --- a/docs/reference/api/python/contrib.rst +++ b/docs/reference/api/python/contrib.rst @@ -92,6 +92,16 @@ tvm.contrib.random .. automodule:: tvm.contrib.random :members: +tvm.contrib.relay_viz +~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: tvm.contrib.relay_viz + :members: RelayVisualizer +.. autoattribute:: tvm.contrib.relay_viz.PlotterBackend.BOKEH +.. autoattribute:: tvm.contrib.relay_viz.PlotterBackend.TERMINAL +.. automodule:: tvm.contrib.relay_viz.plotter + :members: +.. automodule:: tvm.contrib.relay_viz.node_edge_gen + :members: tvm.contrib.rocblas ~~~~~~~~~~~~~~~~~~~ diff --git a/gallery/how_to/work_with_relay/using_relay_viz.py b/gallery/how_to/work_with_relay/using_relay_viz.py new file mode 100644 index 000000000000..827be8ba3c0f --- /dev/null +++ b/gallery/how_to/work_with_relay/using_relay_viz.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=line-too-long +""" +Use Relay Visualizer to Visualize Relay +============================================================ +**Author**: `Chi-Wei Wang `_ + +This is an introduction about using Relay Visualizer to visualize a Relay IR module. + +Relay IR module can contain lots of operations. Although individual +operations are usually easy to understand, they become complicated quickly +when you put them together. It could get even worse while optimiztion passes +come into play. + +This utility abstracts an IR module as graphs containing nodes and edges. +It provides a default parser to interpret an IR modules with nodes and edges. +Two renderer backends are also implemented to visualize them. + +Here we use a backend showing Relay IR module in the terminal for illustation. +It is a much more lightweight compared to another backend using `Bokeh `_. +See ``/python/tvm/contrib/relay_viz/README.md``. +Also we will introduce how to implement customized parsers and renderers through +some interfaces classes. +""" +from typing import ( + Dict, + Union, + Tuple, + List, +) +import tvm +from tvm import relay +from tvm.contrib import relay_viz +from tvm.contrib.relay_viz.node_edge_gen import ( + VizNode, + VizEdge, + NodeEdgeGenerator, +) +from tvm.contrib.relay_viz.terminal import ( + TermNodeEdgeGenerator, + TermGraph, + TermPlotter, +) + +###################################################################### +# Define a Relay IR Module with multiple GlobalVar +# ------------------------------------------------ +# Let's build an example Relay IR Module containing multiple ``GlobalVar``. +# We define an ``add`` function and call it in the main function. +data = relay.var("data") +bias = relay.var("bias") +add_op = relay.add(data, bias) +add_func = relay.Function([data, bias], add_op) +add_gvar = relay.GlobalVar("AddFunc") + +input0 = relay.var("input0") +input1 = relay.var("input1") +input2 = relay.var("input2") +add_01 = relay.Call(add_gvar, [input0, input1]) +add_012 = relay.Call(add_gvar, [input2, add_01]) +main_func = relay.Function([input0, input1, input2], add_012) +main_gvar = relay.GlobalVar("main") + +mod = tvm.IRModule({main_gvar: main_func, add_gvar: add_func}) + +###################################################################### +# Render the graph with Relay Visualizer on the terminal +# ------------------------------------------------------ +# The terminal backend can show a Relay IR module as in a text-form +# similar to `clang ast-dump `_. +# We should see ``main`` and ``AddFunc`` function. ``AddFunc`` is called twice in the ``main`` function. +viz = relay_viz.RelayVisualizer(mod, {}, relay_viz.PlotterBackend.TERMINAL) +viz.render() + +###################################################################### +# Customize Parser for Interested Relay Types +# ------------------------------------------- +# Sometimes the information shown by the default implementation is not suitable +# for a specific usage. It is possible to provide your own parser and renderer. +# Here demostrate how to customize parsers for ``relay.var``. +# We need to implement :py:class:`tvm.contrib.relay_viz.node_edge_gen.NodeEdgeGenerator` interface. +class YourAwesomeParser(NodeEdgeGenerator): + def __init__(self): + self._org_parser = TermNodeEdgeGenerator() + + def get_node_edges( + self, + node: relay.Expr, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, Union[int, str]], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + + if isinstance(node, relay.Var): + node = VizNode(node_to_id[node], "AwesomeVar", f"name_hint {node.name_hint}") + # no edge is introduced. So return an empty list. + ret = (node, []) + return ret + + # delegate other types to the original parser. + return self._org_parser.get_node_edges(node, relay_param, node_to_id) + + +###################################################################### +# Pass a tuple of :py:class:`tvm.contrib.relay_viz.plotter.Plotter` and +# :py:class:`tvm.contrib.relay_viz.node_edge_gen.NodeEdgeGenerator` instances +# to ``RelayVisualizer``. Here we re-use the Plotter interface implemented inside +# ``relay_viz.terminal`` module. +viz = relay_viz.RelayVisualizer(mod, {}, (TermPlotter(), YourAwesomeParser())) +viz.render() + +###################################################################### +# More Customization around Graph and Plotter +# ------------------------------------------- +# All ``RelayVisualizer`` care about are interfaces defined in ``plotter.py`` and +# ``node_edge_generator.py``. We can override them to introduce custimized logics. +# For example, if we want the Graph to duplicate above ``AwesomeVar`` while it is added, +# we can override ``relay_viz.terminal.TermGraph.node``. +class AwesomeGraph(TermGraph): + def node(self, node_id, node_type, node_detail): + # add original node first + super().node(node_id, node_type, node_detail) + if node_type == "AwesomeVar": + duplicated_id = f"duplciated_{node_id}" + duplicated_type = "double AwesomeVar" + super().node(duplicated_id, duplicated_type, "") + # connect the duplicated var to the original one + super().edge(duplicated_id, node_id) + + +# override TermPlotter to return `AwesomeGraph` instead +class AwesomePlotter(TermPlotter): + def create_graph(self, name): + self._name_to_graph[name] = AwesomeGraph(name) + return self._name_to_graph[name] + + +viz = relay_viz.RelayVisualizer(mod, {}, (AwesomePlotter(), YourAwesomeParser())) +viz.render() + +###################################################################### +# Summary +# ------- +# This tutorial demonstrates the usage of Relay Visualizer. +# The class :py:class:`tvm.contrib.relay_viz.RelayVisualizer` is composed of interfaces +# defined in ``plotter.py`` and ``node_edge_generator.py``. It provides a single entry point +# while keeping the possibility of implementing customized visualizer in various cases. +# diff --git a/python/tvm/contrib/relay_viz/README.md b/python/tvm/contrib/relay_viz/README.md new file mode 100644 index 000000000000..c1d7d2245249 --- /dev/null +++ b/python/tvm/contrib/relay_viz/README.md @@ -0,0 +1,78 @@ + + + + + + + + + + + + + + + + + + +# IR Visualization + +This tool target to visualize Relay IR. + +# Table of Contents +1. [Requirement](#Requirement) +2. [Usage](#Usage) +3. [Credits](#Credits) +4. [Design and Customization](#Design-and-Customization) + +## Requirement + +### Terminal Backend +1. TVM + +### Bokeh Backend +1. TVM +2. graphviz +2. pydot +3. bokeh >= 2.3.1 + +``` +# To install TVM, please refer to https://tvm.apache.org/docs/install/from_source.html + +# requirements of pydot +apt-get install graphviz + +# pydot and bokeh +pip install pydot bokeh==2.3.1 +``` + +## Usage + +``` +from tvm.contrib import relay_viz +mod, params = tvm.relay.frontend.from_onnx(net, shape_dict) +vizer = relay_viz.RelayVisualizer(mod, relay_param=params, backend=PlotterBackend.BOKEH) +vizer.render("output.html") +``` + +## Credits + +1. https://github.com/apache/tvm/pull/4370 + +2. https://tvm.apache.org/2020/07/14/bert-pytorch-tvm + +3. https://discuss.tvm.apache.org/t/rfc-visualizing-relay-program-as-graph/4825/17 + +## Design and Customization + +This utility is composed of two parts: `node_edge_gen.py` and `plotter.py`. + +`plotter.py` define interfaces of `Graph` and `Plotter`. `Plotter` is responsible to render a collection of `Graph`. + +`node_edge_gen.py` define interfaces of converting Relay IR modules to nodes and edges. Further, this python module provide a default implementation for common relay types. + +If customization is wanted for a certain relay type, we can implement the `NodeEdgeGenerator` interface, handling that relay type accordingly, and delegate other types to the default implementation. See `_terminal.py` for an example usage. + +These two interfaces are glued by the top level class `RelayVisualizer`, which passes a relay module to `NodeEdgeGenerator` and add nodes and edges to `Graph`. +Then, it render the plot by calling `Plotter.render()`. diff --git a/python/tvm/contrib/relay_viz/__init__.py b/python/tvm/contrib/relay_viz/__init__.py new file mode 100644 index 000000000000..6c84a5a1fd07 --- /dev/null +++ b/python/tvm/contrib/relay_viz/__init__.py @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relay IR Visualizer""" +from typing import ( + Dict, + Tuple, + Union, +) +from enum import Enum +import tvm +from tvm import relay +from .plotter import Plotter +from .node_edge_gen import NodeEdgeGenerator + + +class PlotterBackend(Enum): + """Enumeration for available plotter backends.""" + + BOKEH = "bokeh" + TERMINAL = "terminal" + + +class RelayVisualizer: + """Relay IR Visualizer + + Parameters + ---------- + relay_mod: tvm.IRModule + Relay IR module. + relay_param: None | Dict[str, tvm.runtime.NDArray] + Relay parameter dictionary. Default `None`. + backend: PlotterBackend | Tuple[Plotter, NodeEdgeGenerator] + The backend used to render graphs. It can be a tuple of an implemented Plotter instance and + NodeEdgeGenerator instance to introduce customized parsing and visualization logics. + Default ``PlotterBackend.TERMINAL``. + """ + + def __init__( + self, + relay_mod: tvm.IRModule, + relay_param: Union[None, Dict[str, tvm.runtime.NDArray]] = None, + backend: Union[PlotterBackend, Tuple[Plotter, NodeEdgeGenerator]] = PlotterBackend.TERMINAL, + ): + + self._plotter, self._ne_generator = get_plotter_and_generator(backend) + self._relay_param = relay_param if relay_param is not None else {} + + global_vars = relay_mod.get_global_vars() + graph_names = [] + # If we have main function, put it to the first. + # Then main function can be shown on the top. + for gv_name in global_vars: + if gv_name.name_hint == "main": + graph_names.insert(0, gv_name.name_hint) + else: + graph_names.append(gv_name.name_hint) + + node_to_id = {} + + def traverse_expr(node): + if node in node_to_id: + return + node_to_id[node] = len(node_to_id) + + for name in graph_names: + node_to_id.clear() + relay.analysis.post_order_visit(relay_mod[name], traverse_expr) + graph = self._plotter.create_graph(name) + self._add_nodes(graph, node_to_id, self._relay_param) + + def _add_nodes(self, graph, node_to_id, relay_param): + """add nodes and to the graph. + + Parameters + ---------- + graph : plotter.Graph + + node_to_id : Dict[relay.expr, str | int] + + relay_param : Dict[str, tvm.runtime.NDarray] + """ + for node in node_to_id: + node_info, edge_info = self._ne_generator.get_node_edges(node, relay_param, node_to_id) + if node_info is not None: + graph.node(node_info.identity, node_info.type_str, node_info.detail) + for edge in edge_info: + graph.edge(edge.start, edge.end) + + def render(self, filename: str = None) -> None: + self._plotter.render(filename=filename) + + +def get_plotter_and_generator( + backend: Union[PlotterBackend, Tuple[Plotter, NodeEdgeGenerator]] +) -> Tuple[Plotter, NodeEdgeGenerator]: + """Specify the Plottor and its NodeEdgeGenerator + + Parameters + ---------- + backend : PlotterBackend | Tuple[Plotter, NodeEdgeGenerator] + Backend used to generate nodes/edges and render them. + """ + if isinstance(backend, (tuple, list)) and len(backend) == 2: + if not isinstance(backend[0], Plotter): + raise ValueError(f"First element should be an instance derived from {type(Plotter)}") + + if not isinstance(backend[1], NodeEdgeGenerator): + raise ValueError( + f"Second element should be an instance derived from {type(NodeEdgeGenerator)}" + ) + + return tuple(backend) + + if backend not in PlotterBackend: + raise ValueError(f"Unknown plotter backend {backend}") + + # Plotter modules are Lazy-imported to avoid they become a requirement of TVM. + # Basically we want to keep them optional. Users can choose plotters they want to install. + if backend == PlotterBackend.BOKEH: + # pylint: disable=import-outside-toplevel + from .bokeh import ( + BokehPlotter, + BokehNodeEdgeGenerator, + ) + + plotter = BokehPlotter() + ne_generator = BokehNodeEdgeGenerator() + elif backend == PlotterBackend.TERMINAL: + # pylint: disable=import-outside-toplevel + from .terminal import ( + TermPlotter, + TermNodeEdgeGenerator, + ) + + plotter = TermPlotter() + ne_generator = TermNodeEdgeGenerator() + return plotter, ne_generator diff --git a/python/tvm/contrib/relay_viz/bokeh.py b/python/tvm/contrib/relay_viz/bokeh.py new file mode 100644 index 000000000000..6ea82188463e --- /dev/null +++ b/python/tvm/contrib/relay_viz/bokeh.py @@ -0,0 +1,499 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Bokeh backend for Relay IR Visualizer.""" +import html +import functools +import logging + +import numpy as np +import pydot +from bokeh.io import output_file, save +from bokeh.models import ( + ColumnDataSource, + CustomJS, + Text, + Rect, + HoverTool, + MultiLine, + Legend, + Scatter, + Plot, + TapTool, + PanTool, + ResetTool, + WheelZoomTool, + SaveTool, +) +from bokeh.palettes import ( + d3, +) +from bokeh.layouts import column + +from .plotter import ( + Plotter, + Graph, +) + +from .node_edge_gen import DefaultNodeEdgeGenerator + +_LOGGER = logging.getLogger(__name__) + +# Use default node/edge generator +BokehNodeEdgeGenerator = DefaultNodeEdgeGenerator + + +class NodeDescriptor: + """Descriptor used by Bokeh plotter.""" + + def __init__(self, node_id, node_type, node_detail): + self._node_id = node_id + self._node_type = node_type + self._node_detail = node_detail + + @property + def node_id(self): + return self._node_id + + @property + def node_type(self): + return self._node_type + + @property + def detail(self): + return self._node_detail + + +class GraphShaper: + """Provide the bounding-box, and node location, height, width given by pydot. + To access node attributes, refer to + https://github.com/pydot/pydot/blob/90936e75462c7b0e4bb16d97c1ae7efdf04e895c/src/pydot/core.py#L537 + + To access edge attributes, refer to + https://github.com/pydot/pydot/blob/90936e75462c7b0e4bb16d97c1ae7efdf04e895c/src/pydot/core.py#L645 + + The string format `pos` in an edge follows DOT language spec: + https://graphviz.org/docs/attr-types/splineType/ + """ + + # defined by graphviz. + _PX_PER_INCH = 72 + + def __init__(self, pydot_graph, prog="dot", args=None): + if args is None: + args = [] + # call the graphviz program to get layout + pydot_graph_str = pydot_graph.create([prog] + args, format="dot").decode() + # remember original nodes + self._nodes = [n.get_name() for n in pydot_graph.get_nodes()] + # parse layout + pydot_graph = pydot.graph_from_dot_data(pydot_graph_str) + if len(pydot_graph) != 1: + # should be unlikely. + _LOGGER.warning( + "Got %d pydot graphs. Only the first one will be used.", len(pydot_graph) + ) + self._pydot_graph = pydot_graph[0] + + def get_nodes(self): + return self._nodes + + @functools.lru_cache() + def get_edge_path(self, start_node_id, end_node_id): + """Get explicit path points for MultiLine. + Parse points formating an edge. The format of points in an edge is either: + 1. e,x_point,y_point + 2. s,x_point,y_point + 3. x_point,y_point + We don't care about `e` or `s` here, so simplt parse out x_point and y_point. + """ + edge = self._pydot_graph.get_edge(str(start_node_id), str(end_node_id)) + if len(edge) != 1: + _LOGGER.warning( + "Got %d edges between %s and %s. Only the first one will be used.", + len(edge), + start_node_id, + end_node_id, + ) + edge = edge[0] + # filter out quotes and newline + pos_str = edge.get_pos().strip('"').replace("\\\n", "") + tokens = pos_str.split(" ") + s_token = None + e_token = None + ret_x_pts = [] + ret_y_pts = [] + for token in tokens: + if token.startswith("e,"): + e_token = token + elif token.startswith("s,"): + s_token = token + else: + x_str, y_str = token.split(",") + ret_x_pts.append(float(x_str)) + ret_y_pts.append(float(y_str)) + if s_token is not None: + _, x_str, y_str = s_token.split(",") + ret_x_pts.insert(0, float(x_str)) + ret_y_pts.insert(0, float(y_str)) + if e_token is not None: + _, x_str, y_str = e_token.split(",") + ret_x_pts.append(float(x_str)) + ret_y_pts.append(float(y_str)) + + return ret_x_pts, ret_y_pts + + @functools.lru_cache() + def get_node_pos(self, node_name): + pos_str = self._get_node_attr(node_name, "pos", "0,0") + return list(map(float, pos_str.split(","))) + + def get_node_height(self, node_name): + height_str = self._get_node_attr(node_name, "height", "20") + return float(height_str) * self._PX_PER_INCH + + def get_node_width(self, node_name): + width_str = self._get_node_attr(node_name, "width", "20") + return float(width_str) * self._PX_PER_INCH + + def _get_node_attr(self, node_name, attr_name, default_val): + + node = self._pydot_graph.get_node(str(node_name)) + if len(node) > 1: + _LOGGER.error( + "There are %d nodes with the name %s. Randomly choose one.", len(node), node_name + ) + if len(node) == 0: + _LOGGER.warning( + "%s does not exist in the graph. Use default %s for attribute %s", + node_name, + default_val, + attr_name, + ) + return default_val + + node = node[0] + try: + val = node.obj_dict["attributes"][attr_name].strip('"') + except KeyError: + _LOGGER.warning( + "%s don't exist in node %s. Use default %s", attr_name, node_name, default_val + ) + val = default_val + return val + + +class BokehGraph(Graph): + """Use Bokeh library to plot networks, i.e. nodes and edges.""" + + def __init__(self): + self._pydot_digraph = pydot.Dot(graph_type="digraph") + self._id_to_node = {} + + def node(self, node_id, node_type, node_detail): + # need string for pydot + node_id = str(node_id) + if node_id in self._id_to_node: + _LOGGER.warning("node_id %s already exists.", node_id) + return + self._pydot_digraph.add_node(pydot.Node(node_id, label=node_detail)) + self._id_to_node[node_id] = NodeDescriptor(node_id, node_type, node_detail) + + def edge(self, id_start, id_end): + # need string to pydot + id_start, id_end = str(id_start), str(id_end) + self._pydot_digraph.add_edge(pydot.Edge(id_start, id_end)) + + def render(self, plot): + """To draw a Bokeh Graph""" + shaper = GraphShaper( + self._pydot_digraph, + prog="dot", + args=["-Grankdir=TB", "-Gsplines=ortho", "-Gfontsize=14", "-Nordering=in"], + ) + + self._create_graph(plot, shaper) + + self._add_scalable_glyph(plot, shaper) + return plot + + def _get_type_to_color_map(self): + category20 = d3["Category20"][20] + # FIXME: a problem is, for different network we have different color + # for the same type. + all_types = list({v.node_type for v in self._id_to_node.values()}) + all_types.sort() + if len(all_types) > 20: + _LOGGER.warning( + "The number of types %d is larger than 20. " + "Some colors are re-used for different types.", + len(all_types), + ) + type_to_color = {} + for idx, t in enumerate(all_types): + type_to_color[t] = category20[idx % 20] + return type_to_color + + def _create_graph(self, plot, shaper): + + # Add edge first + edges = self._pydot_digraph.get_edges() + x_path_list = [] + y_path_list = [] + for edge in edges: + id_start = edge.get_source() + id_end = edge.get_destination() + x_pts, y_pts = shaper.get_edge_path(id_start, id_end) + x_path_list.append(x_pts) + y_path_list.append(y_pts) + + multi_line_source = ColumnDataSource({"xs": x_path_list, "ys": y_path_list}) + edge_line_color = "#888888" + edge_line_width = 3 + multi_line_glyph = MultiLine(line_color=edge_line_color, line_width=edge_line_width) + plot.add_glyph(multi_line_source, multi_line_glyph) + + # Then add nodes + type_to_color = self._get_type_to_color_map() + + def cnvt_to_html(s): + return html.escape(s).replace("\n", "
") + + label_to_ids = {} + for node_id in shaper.get_nodes(): + label = self._id_to_node[node_id].node_type + if label not in label_to_ids: + label_to_ids[label] = [] + label_to_ids[label].append(node_id) + + renderers = [] + legend_itmes = [] + for label, id_list in label_to_ids.items(): + source = ColumnDataSource( + { + "x": [shaper.get_node_pos(n)[0] for n in id_list], + "y": [shaper.get_node_pos(n)[1] for n in id_list], + "width": [shaper.get_node_width(n) for n in id_list], + "height": [shaper.get_node_height(n) for n in id_list], + "node_detail": [cnvt_to_html(self._id_to_node[n].detail) for n in id_list], + "node_type": [label] * len(id_list), + } + ) + glyph = Rect(fill_color=type_to_color[label]) + renderer = plot.add_glyph(source, glyph) + # set glyph for interactivity + renderer.nonselection_glyph = Rect(fill_color=type_to_color[label]) + renderer.hover_glyph = Rect( + fill_color=type_to_color[label], line_color="firebrick", line_width=3 + ) + renderer.selection_glyph = Rect( + fill_color=type_to_color[label], line_color="firebrick", line_width=3 + ) + # Though it is called "muted_glyph", we actually use it + # to emphasize nodes in this renderer. + renderer.muted_glyph = Rect( + fill_color=type_to_color[label], line_color="firebrick", line_width=3 + ) + name = f"{self._get_graph_name(plot)}_{label}" + renderer.name = name + renderers.append(renderer) + legend_itmes.append((label, [renderer])) + + # add legend + legend = Legend( + items=legend_itmes, + title="Click to highlight", + inactive_fill_color="firebrick", + inactive_fill_alpha=0.2, + ) + legend.click_policy = "mute" + legend.location = "top_right" + plot.add_layout(legend) + + # add tooltips + tooltips = [ + ("node_type", "@node_type"), + ("description", "@node_detail{safe}"), + ] + inspect_tool = WheelZoomTool() + # only render nodes + hover_tool = HoverTool(tooltips=tooltips, renderers=renderers) + plot.add_tools(PanTool(), TapTool(), inspect_tool, hover_tool, ResetTool(), SaveTool()) + plot.toolbar.active_scroll = inspect_tool + + def _add_scalable_glyph(self, plot, shaper): + nodes = shaper.get_nodes() + + def populate_detail(n_type, n_detail): + if n_detail: + return f"{n_type}\n{n_detail}" + return n_type + + text_source = ColumnDataSource( + { + "x": [shaper.get_node_pos(n)[0] for n in nodes], + "y": [shaper.get_node_pos(n)[1] for n in nodes], + "text": [self._id_to_node[n].node_type for n in nodes], + "detail": [ + populate_detail(self._id_to_node[n].node_type, self._id_to_node[n].detail) + for n in nodes + ], + "box_w": [shaper.get_node_width(n) for n in nodes], + "box_h": [shaper.get_node_height(n) for n in nodes], + } + ) + + text_glyph = Text( + x="x", + y="y", + text="text", + text_align="center", + text_baseline="middle", + text_font_size={"value": "14px"}, + ) + node_annotation = plot.add_glyph(text_source, text_glyph) + + def get_scatter_loc(x_start, x_end, y_start, y_end, end_node): + """return x, y, angle as a tuple""" + node_x, node_y = shaper.get_node_pos(end_node) + node_w = shaper.get_node_width(end_node) + node_h = shaper.get_node_height(end_node) + + # only 4 direction + if x_end - x_start > 0: + return node_x - node_w / 2, y_end, -np.pi / 2 + if x_end - x_start < 0: + return node_x + node_w / 2, y_end, np.pi / 2 + if y_end - y_start < 0: + return x_end, node_y + node_h / 2, np.pi + return x_end, node_y - node_h / 2, 0 + + scatter_source = {"x": [], "y": [], "angle": []} + for edge in self._pydot_digraph.get_edges(): + id_start = edge.get_source() + id_end = edge.get_destination() + x_pts, y_pts = shaper.get_edge_path(id_start, id_end) + x_loc, y_loc, angle = get_scatter_loc( + x_pts[-2], x_pts[-1], y_pts[-2], y_pts[-1], id_end + ) + scatter_source["angle"].append(angle) + scatter_source["x"].append(x_loc) + scatter_source["y"].append(y_loc) + + scatter_glyph = Scatter( + x="x", + y="y", + angle="angle", + size=5, + marker="triangle", + fill_color="#AAAAAA", + fill_alpha=0.8, + ) + edge_end_arrow = plot.add_glyph(ColumnDataSource(scatter_source), scatter_glyph) + + plot.y_range.js_on_change( + "start", + CustomJS( + args=dict( + plot=plot, + node_annotation=node_annotation, + text_source=text_source, + edge_end_arrow=edge_end_arrow, + ), + code=""" + // fontsize is in px + var fontsize = 14 + // ratio = data_point/px + var ratio = (this.end - this.start)/plot.height + var text_list = text_source.data["text"] + var detail_list = text_source.data["detail"] + var box_h_list = text_source.data["box_h"] + for(var i = 0; i < text_list.length; i++) { + var line_num = Math.floor((box_h_list[i]/ratio) / (fontsize*1.5)) + if(line_num <= 0) { + // relieve for the first line + if(Math.floor((box_h_list[i]/ratio) / (fontsize)) > 0) { + line_num = 1 + } + } + var lines = detail_list[i].split("\\n") + lines = lines.slice(0, line_num) + text_list[i] = lines.join("\\n") + } + text_source.change.emit() + + node_annotation.glyph.text_font_size = {value: `${fontsize}px`} + + var new_scatter_size = Math.round(fontsize / ratio) + edge_end_arrow.glyph.size = {value: new_scatter_size} + """, + ), + ) + + @staticmethod + def _get_graph_name(plot): + return plot.title + + +class BokehPlotter(Plotter): + """Render and save collections of class BokehGraph.""" + + def __init__(self): + self._name_to_graph = {} + + def create_graph(self, name): + if name in self._name_to_graph: + _LOGGER.warning("Graph name %s already exists.", name) + else: + self._name_to_graph[name] = BokehGraph() + return self._name_to_graph[name] + + def render(self, filename): + if filename is None: + filename = "bokeh_plotter_result.html" + elif not filename.endswith(".html"): + filename = f"{filename}.html" + + dom_list = [] + for name, graph in self._name_to_graph.items(): + plot = Plot( + title=name, + width=1600, + height=900, + align="center", + margin=(0, 0, 0, 70), + ) + + dom = graph.render(plot) + dom_list.append(dom) + + self._save_html(filename, column(*dom_list)) + + def _save_html(self, filename, layout_dom): + + output_file(filename, title=filename) + + template = """ + {% block postamble %} + + {% endblock %} + """ + + save(layout_dom, filename=filename, title=filename, template=template) diff --git a/python/tvm/contrib/relay_viz/node_edge_gen.py b/python/tvm/contrib/relay_viz/node_edge_gen.py new file mode 100644 index 000000000000..cc60266a4c2f --- /dev/null +++ b/python/tvm/contrib/relay_viz/node_edge_gen.py @@ -0,0 +1,268 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""NodeEdgeGenerator interface for :py:class:`tvm.contrib.relay_viz.plotter.Graph`.""" +import abc +from typing import ( + Dict, + Union, + Tuple, + List, +) +import tvm +from tvm import relay + +UNKNOWN_TYPE = "unknown" + + +class VizNode: + """Node carry information used by `plotter.Graph` interface. + + Parameters + ---------- + node_id: int | str + Unique identifier for this node. + node_type: str + Type of this node. + node_detail: str + Any supplement for this node such as attributes. + """ + + def __init__(self, node_id: Union[int, str], node_type: str, node_detail: str): + self._id = node_id + self._type = node_type + self._detail = node_detail + + @property + def identity(self) -> Union[int, str]: + return self._id + + @property + def type_str(self) -> str: + return self._type + + @property + def detail(self) -> str: + return self._detail + + +class VizEdge: + """Edges for `plotter.Graph` interface. + + Parameters + ---------- + start_node: int | str + The identifier of the node starting the edge. + end_node: int | str + The identifier of the node ending the edge. + """ + + def __init__(self, start_node: Union[int, str], end_node: Union[int, str]): + self._start_node = start_node + self._end_node = end_node + + @property + def start(self) -> Union[int, str]: + return self._start_node + + @property + def end(self) -> Union[int, str]: + return self._end_node + + +class NodeEdgeGenerator(abc.ABC): + """An interface class to generate nodes and edges information for Graph interfaces.""" + + @abc.abstractmethod + def get_node_edges( + self, + node: relay.Expr, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, Union[int, str]], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + """Generate node and edges consumed by Graph interfaces. + + Parameters + ---------- + node : relay.Expr + relay.Expr which will be parsed and generate a node and edges. + + relay_param: Dict[str, tvm.runtime.NDArray] + relay parameters dictionary. + + node_to_id : Dict[relay.Expr, Union[int, str]] + a mapping from relay.Expr to node id which should be unique. + + Returns + ------- + rv1 : Union[VizNode, None] + VizNode represent the relay.Expr. If the relay.Expr is not intended to introduce a node + to the graph, return None. + + rv2 : List[VizEdge] + a list of VizEdge to describe the connectivity of the relay.Expr. + Can be empty list to indicate no connectivity. + """ + + +class DefaultNodeEdgeGenerator(NodeEdgeGenerator): + """NodeEdgeGenerator generate for nodes and edges consumed by Graph. + This class is a default implementation for common relay types, heavily based on + `visualize` function in https://tvm.apache.org/2020/07/14/bert-pytorch-tvm + """ + + def __init__(self): + self._render_rules = {} + self._build_rules() + + def get_node_edges( + self, + node: relay.Expr, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, Union[int, str]], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + try: + node_info, edge_info = self._render_rules[type(node)](node, relay_param, node_to_id) + except KeyError: + node_info = VizNode( + node_to_id[node], UNKNOWN_TYPE, f"don't know how to parse {type(node)}" + ) + edge_info = [] + return node_info, edge_info + + def _var_node( + self, + node: relay.Expr, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, Union[int, str]], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + """Render rule for a relay var node""" + node_id = node_to_id[node] + name_hint = node.name_hint + node_detail = f"name_hint: {name_hint}" + node_type = "Var(Param)" if name_hint in relay_param else "Var(Input)" + if node.type_annotation is not None: + if hasattr(node.type_annotation, "shape"): + shape = tuple(map(int, node.type_annotation.shape)) + dtype = node.type_annotation.dtype + node_detail = f"name_hint: {name_hint}\nshape: {shape}\ndtype: {dtype}" + else: + node_detail = f"name_hint: {name_hint}\ntype_annotation: {node.type_annotation}" + node_info = VizNode(node_id, node_type, node_detail) + edge_info = [] + return node_info, edge_info + + def _function_node( + self, + node: relay.Expr, + _: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, Union[int, str]], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + """Render rule for a relay function node""" + node_details = [] + name = "" + func_attrs = node.attrs + if func_attrs: + node_details = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()] + # "Composite" might from relay.transform.MergeComposite + if "Composite" in func_attrs.keys(): + name = func_attrs["Composite"] + node_id = node_to_id[node] + node_info = VizNode(node_id, f"Func {name}", "\n".join(node_details)) + edge_info = [VizEdge(node_to_id[node.body], node_id)] + return node_info, edge_info + + def _call_node( + self, + node: relay.Expr, + _: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, Union[int, str]], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + """Render rule for a relay call node""" + node_id = node_to_id[node] + op_name = UNKNOWN_TYPE + node_detail = [] + if isinstance(node.op, tvm.ir.Op): + op_name = node.op.name + if node.attrs: + node_detail = [f"{k}: {node.attrs.get_str(k)}" for k in node.attrs.keys()] + elif isinstance(node.op, relay.Function): + func_attrs = node.op.attrs + op_name = "Anonymous Func" + if func_attrs: + node_detail = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()] + # "Composite" might from relay.transform.MergeComposite + if "Composite" in func_attrs.keys(): + op_name = func_attrs["Composite"] + elif isinstance(node.op, relay.GlobalVar): + op_name = "GlobalVar" + node_detail = [f"GlobalVar.name_hint: {node.op.name_hint}"] + else: + op_name = str(type(node.op)).split(".")[-1].split("'")[0] + + node_info = VizNode(node_id, f"Call {op_name}", "\n".join(node_detail)) + args = [node_to_id[arg] for arg in node.args] + edge_info = [VizEdge(arg, node_id) for arg in args] + return node_info, edge_info + + def _tuple_node( + self, + node: relay.Expr, + _: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, Union[int, str]], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + node_id = node_to_id[node] + node_info = VizNode(node_id, "Tuple", "") + edge_info = [VizEdge(node_to_id[field], node_id) for field in node.fields] + return node_info, edge_info + + def _tuple_get_item_node( + self, + node: relay.Expr, + _: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, Union[int, str]], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + node_id = node_to_id[node] + node_info = VizNode(node_id, f"TupleGetItem", "idx: {node.index}") + edge_info = [VizEdge(node_to_id[node.tuple_value], node_id)] + return node_info, edge_info + + def _constant_node( + self, + node: relay.Expr, + _: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, Union[int, str]], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + node_id = node_to_id[node] + node_detail = f"shape: {node.data.shape}, dtype: {node.data.dtype}" + node_info = VizNode(node_id, "Const", node_detail) + edge_info = [] + return node_info, edge_info + + def _null(self, *_) -> Tuple[None, List[VizEdge]]: + return None, [] + + def _build_rules(self): + self._render_rules = { + tvm.relay.Function: self._function_node, + tvm.relay.expr.Call: self._call_node, + tvm.relay.expr.Var: self._var_node, + tvm.relay.expr.Tuple: self._tuple_node, + tvm.relay.expr.TupleGetItem: self._tuple_get_item_node, + tvm.relay.expr.Constant: self._constant_node, + tvm.relay.expr.GlobalVar: self._null, + tvm.ir.Op: self._null, + } diff --git a/python/tvm/contrib/relay_viz/plotter.py b/python/tvm/contrib/relay_viz/plotter.py new file mode 100644 index 000000000000..de8c24c39a40 --- /dev/null +++ b/python/tvm/contrib/relay_viz/plotter.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Abstract class used by :py:class:`tvm.contrib.relay_viz.RelayVisualizer`.""" +import abc +from typing import Union + + +class Graph(abc.ABC): + """Abstract class for graph, which is composed of nodes and edges.""" + + @abc.abstractmethod + def node(self, node_id: Union[int, str], node_type: str, node_detail: str) -> None: + """Add a node to the underlying graph. + + Parameters + ---------- + node_id : Union[int, str] + Serve as the ID to the node. + + node_type : str + the type of the node. + + node_detail : str + the description of the node. + """ + + @abc.abstractmethod + def edge(self, id_start: Union[int, str], id_end: Union[int, str]) -> None: + """Add an edge to the underlying graph. + + Parameters + ---------- + id_start : Union[int, str] + the ID to the starting node. + + id_end : Union[int, str] + the ID to the ending node. + """ + + +class Plotter(abc.ABC): + """Abstract class for plotters, rendering a collection of Graph interface.""" + + @abc.abstractmethod + def create_graph(self, name: str) -> Graph: + """Create a graph + + Parameters + ---------- + name : str + the name of the graph + + Return + ------ + rv1: class Graph + """ + + @abc.abstractmethod + def render(self, filename: str) -> None: + """Render the graph as a file. + + Parameters + ---------- + filename : str + see the definition of implemented class. + """ diff --git a/python/tvm/contrib/relay_viz/terminal.py b/python/tvm/contrib/relay_viz/terminal.py new file mode 100644 index 000000000000..82332d964907 --- /dev/null +++ b/python/tvm/contrib/relay_viz/terminal.py @@ -0,0 +1,208 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Visualize Relay IR in AST text-form.""" + +from collections import deque +from typing import ( + Dict, + Union, + Tuple, + List, +) + +import tvm +from tvm import relay + +from .plotter import ( + Plotter, + Graph, +) + +from .node_edge_gen import ( + VizNode, + VizEdge, + NodeEdgeGenerator, + DefaultNodeEdgeGenerator, +) + + +class TermNodeEdgeGenerator(NodeEdgeGenerator): + """Terminal nodes and edges generator.""" + + def __init__(self): + self._default_ne_gen = DefaultNodeEdgeGenerator() + + def get_node_edges( + self, + node: relay.Expr, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, Union[int, str]], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + """Generate node and edges consumed by TermGraph interfaces""" + if isinstance(node, relay.Call): + return self._call_node(node, node_to_id) + + if isinstance(node, relay.Let): + return self._let_node(node, node_to_id) + + if isinstance(node, relay.GlobalVar): + return self._global_var_node(node, node_to_id) + + if isinstance(node, relay.If): + return self._if_node(node, node_to_id) + + if isinstance(node, tvm.ir.Op): + return self._op_node(node, node_to_id) + + if isinstance(node, relay.Function): + return self._function_node(node, node_to_id) + + # otherwise, delegate to the default implementation + return self._default_ne_gen.get_node_edges(node, relay_param, node_to_id) + + def _call_node(self, node, node_to_id): + node_id = node_to_id[node] + node_info = VizNode(node_id, "Call", "") + edge_info = [VizEdge(node_to_id[node.op], node_id)] + for arg in node.args: + arg_nid = node_to_id[arg] + edge_info.append(VizEdge(arg_nid, node_id)) + return node_info, edge_info + + def _let_node(self, node, node_to_id): + node_id = node_to_id[node] + node_info = VizNode(node_id, "Let", "(var, val, body)") + edge_info = [ + VizEdge(node_to_id[node.var], node_id), + VizEdge(node_to_id[node.value], node_id), + VizEdge(node_to_id[node.body], node_id), + ] + return node_info, edge_info + + def _global_var_node(self, node, node_to_id): + node_id = node_to_id[node] + node_info = VizNode(node_id, "GlobalVar", node.name_hint) + edge_info = [] + return node_info, edge_info + + def _if_node(self, node, node_to_id): + node_id = node_to_id[node] + node_info = VizNode(node_id, "If", "(cond, true, false)") + edge_info = [ + VizEdge(node_to_id[node.cond], node_id), + VizEdge(node_to_id[node.true_branch], node_id), + VizEdge(node_to_id[node.false_branch], node_id), + ] + return node_info, edge_info + + def _op_node(self, node, node_to_id): + node_id = node_to_id[node] + op_name = node.name + node_info = VizNode(node_id, op_name, "") + edge_info = [] + return node_info, edge_info + + def _function_node(self, node, node_to_id): + node_id = node_to_id[node] + node_info = VizNode(node_id, "Func", str(node.params)) + edge_info = [VizEdge(node_to_id[node.body], node_id)] + return node_info, edge_info + + +class TermNode: + def __init__(self, node_type, other_info): + self.type = node_type + self.other_info = other_info.replace("\n", ", ") + + +class TermGraph(Graph): + """Terminal plot for a relay IR Module""" + + def __init__(self, name): + # node_id: [ connected node_id] + self._name = name + self._graph = {} + self._id_to_node = {} + # reversed post order + self._node_ids_rpo = deque() + + def node(self, node_id, node_type, node_detail): + # actually we just need the last one. + self._node_ids_rpo.appendleft(node_id) + + if node_id not in self._graph: + self._graph[node_id] = [] + + node = TermNode(node_type, node_detail) + self._id_to_node[node_id] = node + + def edge(self, id_start, id_end): + # want reserved post-order + if id_end in self._graph: + self._graph[id_end].append(id_start) + else: + self._graph[id_end] = [id_start] + + def render(self): + """To draw a terminal graph""" + lines = [] + seen_node = set() + + def gen_line(indent, n_id): + if (indent, n_id) in seen_node: + return + seen_node.add((indent, n_id)) + + conn_symbol = ["|--", "`--"] + last = len(self._graph[n_id]) - 1 + for i, next_n_id in enumerate(self._graph[n_id]): + node = self._id_to_node[next_n_id] + lines.append( + f"{indent}{conn_symbol[1 if i==last else 0]}{node.type} {node.other_info}" + ) + next_indent = indent + next_indent += " " if (i == last) else "| " + gen_line(next_indent, next_n_id) + + first_node_id = self._node_ids_rpo[0] + node = self._id_to_node[first_node_id] + lines.append(f"@{self._name}({node.other_info})") + gen_line("", first_node_id) + + return "\n".join(lines) + + +class TermPlotter(Plotter): + """Terminal plotter""" + + def __init__(self): + self._name_to_graph = {} + + def create_graph(self, name): + self._name_to_graph[name] = TermGraph(name) + return self._name_to_graph[name] + + def render(self, filename): + """If filename is None, print to stdio. Otherwise, write to the filename.""" + lines = [] + for name in self._name_to_graph: + lines.append(self._name_to_graph[name].render()) + if filename is None: + print("\n".join(lines)) + else: + with open(filename, "w") as out_file: + out_file.write("\n".join(lines)) diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 5540c35a8f7e..b467c58fa8f7 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -402,7 +402,6 @@ TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break TVM_REGISTER_TARGET_KIND("composite", kDLCPU).add_attr_option>("devices"); - /********** Registry **********/ TVM_REGISTER_GLOBAL("target.ListTargetKinds").set_body_typed(TargetKindRegEntry::ListTargetKinds); diff --git a/tests/python/contrib/test_relay_viz.py b/tests/python/contrib/test_relay_viz.py new file mode 100644 index 000000000000..3e5b73eae0ec --- /dev/null +++ b/tests/python/contrib/test_relay_viz.py @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import relay +from tvm.contrib.relay_viz import node_edge_gen +from tvm.contrib.relay_viz.node_edge_gen import DefaultNodeEdgeGenerator + +# the testing focus on that DefaultNodeEdgeGenerator can +# parse Relay IR properly. + + +def test_var(): + ne_gen = DefaultNodeEdgeGenerator() + shape = (10, 10) + input_var = relay.var("input", shape=shape) + node, edges = ne_gen.get_node_edges(input_var, {}, {input_var: 1}) + assert node.identity == 1, "node_id should be 1." + assert "input" in node.detail, "detail should have name_hint." + assert str(shape) in node.detail, "detail should have shape." + assert len(edges) == 0, "relay.var doesn't cause any edge." + + +def test_function(): + ne_gen = DefaultNodeEdgeGenerator() + input_var = relay.var("input") + bias_var = relay.var("bias") + add_bias = relay.add(input_var, bias_var) + func = relay.Function([input_var, bias_var], add_bias) + node, edges = ne_gen.get_node_edges(func, {}, {func: 99, add_bias: 199}) + assert node.identity == 99, "node_id should be 99." + assert edges[0].start == 199, "edge.start should be node 199." + assert edges[0].end == 99, "edge.end should be node 99." + + +def test_call(): + ne_gen = DefaultNodeEdgeGenerator() + input_var = relay.var("input") + bias_var = relay.var("bias") + add_bias = relay.add(input_var, bias_var) + node, edges = ne_gen.get_node_edges(add_bias, {}, {add_bias: 1, input_var: 0, bias_var: 2}) + assert "add" in node.type_str, "node_type shuold contain op_name." + assert len(edges) == 2, "the length of edges should be 2, from two var to relay.add." + + +def test_tuple(): + ne_gen = DefaultNodeEdgeGenerator() + elemt0_var = relay.var("elemt0") + elemt1_var = relay.var("elemt1") + tup = relay.Tuple([elemt0_var, elemt1_var]) + node, edges = ne_gen.get_node_edges(tup, {}, {tup: 123, elemt0_var: 0, elemt1_var: 1}) + assert node.identity == 123, "node_id should be 123." + assert len(edges) == 2, "the length of edges should be 2, from two relay.var to tuple." + assert edges[0].start == 0 and edges[0].end == 123, "edges[0] should be 0 -> 123." + assert edges[1].start == 1 and edges[1].end == 123, "edges[1] should be 1 -> 123." + + +def test_constant(): + ne_gen = DefaultNodeEdgeGenerator() + arr = tvm.nd.array(10) + const = relay.Constant(arr) + node, edges = ne_gen.get_node_edges(const, {}, {const: 999}) + assert node.identity == 999, "node_id should be 999." + assert len(edges) == 0, "constant should not cause edges." + + arr = tvm.nd.array([[10, 11]]) + const = relay.Constant(arr) + node, edges = ne_gen.get_node_edges(const, {}, {const: 111}) + assert str(const.data.shape) in node.detail, "node_detail should contain shape." + + +if __name__ == "__main__": + test_var() + test_function() + test_call() + test_tuple() + test_constant()