prune dnnl subgraph, and add related test case. (apache#10835)
crazydemo authored and pfk-beta committed Apr 11, 2022
1 parent bb43e96 commit 5436907
Showing 2 changed files with 139 additions and 8 deletions.
96 changes: 95 additions & 1 deletion python/tvm/relay/op/contrib/dnnl.py
@@ -36,6 +36,9 @@

import tvm.ir
from tvm import relay
from tvm.relay import transform
from tvm.relay.expr import GlobalVar
from tvm.relay.expr_functor import ExprMutator, ExprVisitor

from ... import _ffi_api
from ...dataflow_pattern import wildcard, is_op
@@ -83,7 +86,6 @@ def _func_wrapper(expr):
_register_external_op_helper("log")
_register_external_op_helper("sqrt")
_register_external_op_helper("round")
_register_external_op_helper("logsumexp")
_register_external_op_helper("nn.relu")
_register_external_op_helper("nn.leaky_relu")
_register_external_op_helper("tanh")
@@ -411,3 +413,95 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
    if conv_type == "Conv2DTranspose":
        return relay.nn.conv2d_transpose(data, weight, **new_attrs)
    return relay.nn.conv3d_transpose(data, weight, **new_attrs)


class IsComputeIntensiveGraph(ExprVisitor):
    """
    Visits the graph recursively and checks whether it contains compute-heavy ops
    such as convolutions, their transposed variants, and dense.
    """

    def __init__(self):
        ExprVisitor.__init__(self)
        self.is_compute_intensive = False

    def visit_call(self, call):
        compute_intensive_ops = set(
            [
                "nn.conv1d",
                "nn.conv2d",
                "nn.conv2d_transpose",
                "nn.conv3d",
                "nn.conv3d_transpose",
                "nn.dense",
            ]
        )
        if isinstance(call.op, tvm.tir.op.Op):
            if str(call.op) in compute_intensive_ops:
                self.is_compute_intensive = True

        return super().visit_call(call)

    def is_graph_compute_intensive(self, subgraph) -> bool:
        """
        Recursively visits the graph and checks whether it is compute intensive.
        """
        self.visit(subgraph)
        return self.is_compute_intensive


def is_valid_subgraph(body):
    """Final check on whether the subgraph is valid and should be offloaded to DNNL."""
    return IsComputeIntensiveGraph().is_graph_compute_intensive(body)


def prune_dnnl_subgraphs(mod):
    """
    Removes invalid subgraphs, i.e. those that do not contain any compute-intensive DNNL ops.
    """

    class SubgraphRemover(ExprMutator):
        """
        Reverts subgraphs in subgraphs_to_remove back to TVM instead of using an external codegen.
        """

        def __init__(self, subgraphs_to_remove, mod, new_mod):
            ExprMutator.__init__(self)
            self.subgraphs_to_remove = subgraphs_to_remove
            self.mod = mod
            self.new_mod = new_mod

        def visit_call(self, call):
            if isinstance(call.op, GlobalVar):
                name = call.op.name_hint
                if name in self.subgraphs_to_remove:
                    # "Inline" the subgraph back into the new main function.
                    func = self.mod[name]
                    var_map = {}
                    for arg, param in zip(call.args, func.params):
                        var_map[param] = super().visit(arg)
                    new_body = relay.bind(func.body, var_map)
                    return new_body
                if name != "main":
                    args = []
                    for arg in call.args:
                        args.append(super().visit(arg))
                    return call.op(*args)
            return super().visit_call(call)

    subgraphs_to_remove = []
    # If there is only one subgraph, do nothing.
    if len(mod.get_global_vars()) <= 2:
        return mod
    # Collect invalid (non-compute-intensive) subgraphs.
    for subgraph in mod.get_global_vars():
        name = subgraph.name_hint
        if not mod[name].attrs or mod[name].attrs["Compiler"] != "dnnl":
            continue
        if not is_valid_subgraph(mod[name].body):
            subgraphs_to_remove.append(name)
    # Create the new, pruned module.
    new_mod = tvm.IRModule(mod.functions, mod.type_definitions)
    new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"])
    new_mod = transform.RemoveUnusedFunctions()(new_mod)
    return new_mod
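For reference, a minimal sketch of how the new visitor behaves in isolation, assuming this patch is applied (the two small expressions below are hypothetical examples, not part of the patch): graphs made only of element-wise ops are reported as not compute intensive and would be pruned, while graphs containing a dense or convolution are kept for DNNL.

import tvm
from tvm import relay
from tvm.relay.op.contrib.dnnl import IsComputeIntensiveGraph

x = relay.var("x", shape=(1, 8), dtype="float32")
w = relay.var("w", shape=(4, 8), dtype="float32")

# Only element-wise ops: not compute intensive, so such a subgraph would be pruned.
elementwise_body = relay.add(relay.nn.relu(x), relay.const(1.0))
print(IsComputeIntensiveGraph().is_graph_compute_intensive(elementwise_body))  # False

# Contains nn.dense: compute intensive, so the subgraph stays with the DNNL codegen.
dense_body = relay.nn.relu(relay.nn.dense(x, w))
print(IsComputeIntensiveGraph().is_graph_compute_intensive(dense_body))  # True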
51 changes: 44 additions & 7 deletions tests/python/contrib/test_dnnl.py
@@ -103,6 +103,7 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
    )
    with tvm.transform.PassContext(opt_level=3):
        mod = byoc_seq(mod)
        mod = dnnl.prune_dnnl_subgraphs(mod)
    return mod


@@ -123,12 +124,15 @@ def assert_result_dict_holds(result_dict):
tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3)


def run_and_verify(mod, input, params, target, run_module):
def check_dnnl_used(mod):
def run_and_verify(mod, input, params, target, run_module, subgraph_num=None):
def check_dnnl_used(mod, subgraph_num=None):
num_dnnl_subgraphs = sum(
[1 if "dnnl" in gv.name_hint else 0 for gv in mod.get_global_vars()]
)
assert num_dnnl_subgraphs >= 1
if subgraph_num:
assert num_dnnl_subgraphs == subgraph_num
else:
assert num_dnnl_subgraphs >= 1

dev = tvm.cpu()
result_dict = dict()
@@ -137,7 +141,7 @@ def check_dnnl_used(mod):
result_key = mode + ("_dnnl" if use_dnnl else "") + ("_layout" if alter_layout else "")
if use_dnnl:
processed_mod = partition_for_dnnl(mod, params, alter_layout)
check_dnnl_used(processed_mod)
check_dnnl_used(processed_mod, subgraph_num)
else:
processed_mod = mod
with tvm.transform.PassContext(opt_level=3):
@@ -154,7 +158,7 @@ def check_dnnl_used(mod):
assert_result_dict_holds(result_dict)


def run_and_verify_func(config, run_module, target="llvm", dtype="float32"):
def run_and_verify_func(config, run_module, subgraph_num=None, target="llvm", dtype="float32"):
"""Test a Relay func by compiling, running, and comparing TVM and DNNL outputs.
Parameters
----------
@@ -171,7 +175,9 @@ def run_and_verify_func(config, run_module, target="llvm", dtype="float32"):
for k, v in input_shapes.items()
if k not in is_param
}
run_and_verify(f, input_dict, params, target=target, run_module=run_module)
run_and_verify(
f, input_dict, params, subgraph_num=subgraph_num, target=target, run_module=run_module
)


def get_conv1d(
@@ -574,7 +580,6 @@ def get_graph(op, x_shape=(1, 8, 3, 3)):
relay.log,
relay.sqrt,
relay.round,
relay.logsumexp,
relay.nn.relu,
relay.tanh,
relay.sigmoid,
@@ -935,6 +940,38 @@ def get_graph(
run_and_verify_func(get_graph(relay.nn.max_pool3d, strides=(1, 1, 1)), run_module=run_module)


def test_prune_dnnl_subgraph(run_module):
    """In this test, the "add" op should be pruned from the DNNL partition and fall back to TVM."""

    def get_graph():
        x1 = relay.var("x1", shape=(1, 64, 56, 56))
        x2 = relay.var("x2", shape=(1, 64, 56, 56))
        bias = relay.var("bias", shape=(64,))
        weight = relay.var("weight", shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(
            x1,
            weight,
            channels=64,
            kernel_size=(3, 3),
            padding=(1, 1),
        )
        y = relay.nn.bias_add(y, bias)
        y = relay.nn.relu(y)
        y = relay.nn.global_max_pool2d(y)
        y = relay.add(y, x2)
        dic = {
            "x1": (1, 64, 56, 56),
            "x2": (1, 64, 56, 56),
            "weight": (64, 64, 3, 3),
            "bias": (64,),
        }
        param_lst = ["weight", "bias"]
        out = tvm.IRModule.from_expr(y)
        return out, dic, param_lst

    run_and_verify_func(get_graph(), subgraph_num=1, run_module=run_module)


if __name__ == "__main__":
import sys

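A rough sketch of what the new test asserts, for readers who want to see the pruning effect directly: build a graph shaped like the one in test_prune_dnnl_subgraph, partition it for DNNL, and count the DNNL subgraphs before and after prune_dnnl_subgraphs. The pass sequence and the counts in the comments are illustrative assumptions; the exact number before pruning depends on which ops the DNNL annotator claims.

import tvm
from tvm import relay
from tvm.relay.op.contrib import dnnl


def count_dnnl_subgraphs(mod):
    # Each DNNL subgraph becomes a global function whose name contains "dnnl".
    return sum(1 for gv in mod.get_global_vars() if "dnnl" in gv.name_hint)


# Same shape of graph as the new test: conv2d + relu (compute intensive),
# followed by a pooling/add tail that is not.
x1 = relay.var("x1", shape=(1, 64, 56, 56))
x2 = relay.var("x2", shape=(1, 64, 56, 56))
weight = relay.var("weight", shape=(64, 64, 3, 3))
y = relay.nn.relu(relay.nn.conv2d(x1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)))
y = relay.add(relay.nn.global_max_pool2d(y), x2)
mod = tvm.IRModule.from_expr(y)

seq = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.MergeComposite(dnnl.pattern_table()),
        relay.transform.AnnotateTarget("dnnl"),
        relay.transform.MergeCompilerRegions(),
        relay.transform.PartitionGraph(),
    ]
)
with tvm.transform.PassContext(opt_level=3):
    partitioned = seq(mod)

print(count_dnnl_subgraphs(partitioned))  # e.g. 2 before pruning
print(count_dnnl_subgraphs(dnnl.prune_dnnl_subgraphs(partitioned)))  # 1 afterwards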
