diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 2bcb2b0ef7f8c..72e004b86853f 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -36,6 +36,9 @@
 
 import tvm.ir
 from tvm import relay
+from tvm.relay import transform
+from tvm.relay.expr import GlobalVar
+from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 
 from ... import _ffi_api
 from ...dataflow_pattern import wildcard, is_op
@@ -83,7 +86,6 @@ def _func_wrapper(expr):
 _register_external_op_helper("log")
 _register_external_op_helper("sqrt")
 _register_external_op_helper("round")
-_register_external_op_helper("logsumexp")
 _register_external_op_helper("nn.relu")
 _register_external_op_helper("nn.leaky_relu")
 _register_external_op_helper("tanh")
@@ -411,3 +413,95 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
     if conv_type == "Conv2DTranspose":
         return relay.nn.conv2d_transpose(data, weight, **new_attrs)
     return relay.nn.conv3d_transpose(data, weight, **new_attrs)
+
+
+class IsComputeIntensiveGraph(ExprVisitor):
+    """
+    Visits the graph recursively and checks if it contains compute-heavy ops such as
+    convolutions, their transposes, and dense.
+    """
+
+    def __init__(self):
+        ExprVisitor.__init__(self)
+        self.is_compute_intensive = False
+
+    def visit_call(self, call):
+        compute_intensive_ops = set(
+            [
+                "nn.conv1d",
+                "nn.conv2d",
+                "nn.conv2d_transpose",
+                "nn.conv3d",
+                "nn.conv3d_transpose",
+                "nn.dense",
+            ]
+        )
+        if isinstance(call.op, tvm.tir.op.Op):
+            if str(call.op) in compute_intensive_ops:
+                self.is_compute_intensive = True
+
+        return super().visit_call(call)
+
+    def is_graph_compute_intensive(self, subgraph) -> bool:
+        """
+        This function recursively visits the graph and checks if it's compute intensive.
+        """
+        self.visit(subgraph)
+        return self.is_compute_intensive
+
+
+def is_valid_subgraph(body):
+    """Final check on whether the subgraph is valid and should be offloaded to DNNL."""
+    return IsComputeIntensiveGraph().is_graph_compute_intensive(body)
+
+
+def prune_dnnl_subgraphs(mod):
+    """
+    Removes invalid subgraphs, i.e. those that do not contain compute-intensive dnnl ops.
+    """
+
+    class SubgraphRemover(ExprMutator):
+        """
+        Reverts subgraphs in subgraphs_to_remove back to TVM instead of using an external codegen.
+        """
+
+        def __init__(self, subgraphs_to_remove, mod, new_mod):
+            ExprMutator.__init__(self)
+            self.subgraphs_to_remove = subgraphs_to_remove
+            self.mod = mod
+            self.new_mod = new_mod
+
+        def visit_call(self, call):
+            if isinstance(call.op, GlobalVar):
+                name = call.op.name_hint
+                if name in self.subgraphs_to_remove:
+                    # "Inline" the subgraph back into the new main function.
+                    func = self.mod[name]
+                    var_map = {}
+                    for arg, param in zip(call.args, func.params):
+                        var_map[param] = super().visit(arg)
+                    new_body = relay.bind(func.body, var_map)
+                    return new_body
+                if name != "main":
+                    args = []
+                    for arg in call.args:
+                        args.append(super().visit(arg))
+                    return call.op(*args)
+            return super().visit_call(call)
+
+    subgraphs_to_remove = []
+    # If there is only one subgraph, do nothing.
+    if len(mod.get_global_vars()) <= 2:
+        return mod
+    # Collect invalid subgraphs.
+    for subgraph in mod.get_global_vars():
+        name = subgraph.name_hint
+        if not mod[name].attrs or mod[name].attrs["Compiler"] != "dnnl":
+            continue
+        if not is_valid_subgraph(mod[name].body):
+            subgraphs_to_remove.append(name)
+    # Create the new pruned module.
+    new_mod = tvm.IRModule(mod.functions, mod.type_definitions)
+    new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"])
+    new_mod = transform.RemoveUnusedFunctions()(new_mod)
+    return new_mod
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index fb48e05c4d806..8ddda578b3a28 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -103,6 +103,7 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
     )
     with tvm.transform.PassContext(opt_level=3):
         mod = byoc_seq(mod)
+        mod = dnnl.prune_dnnl_subgraphs(mod)
     return mod
 
 
@@ -123,12 +124,15 @@ def assert_result_dict_holds(result_dict):
             tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3)
 
 
-def run_and_verify(mod, input, params, target, run_module):
-    def check_dnnl_used(mod):
+def run_and_verify(mod, input, params, target, run_module, subgraph_num=None):
+    def check_dnnl_used(mod, subgraph_num=None):
         num_dnnl_subgraphs = sum(
             [1 if "dnnl" in gv.name_hint else 0 for gv in mod.get_global_vars()]
         )
-        assert num_dnnl_subgraphs >= 1
+        if subgraph_num:
+            assert num_dnnl_subgraphs == subgraph_num
+        else:
+            assert num_dnnl_subgraphs >= 1
 
     dev = tvm.cpu()
     result_dict = dict()
@@ -137,7 +141,7 @@ def check_dnnl_used(mod):
             result_key = mode + ("_dnnl" if use_dnnl else "") + ("_layout" if alter_layout else "")
             if use_dnnl:
                 processed_mod = partition_for_dnnl(mod, params, alter_layout)
-                check_dnnl_used(processed_mod)
+                check_dnnl_used(processed_mod, subgraph_num)
             else:
                 processed_mod = mod
             with tvm.transform.PassContext(opt_level=3):
@@ -154,7 +158,7 @@ def check_dnnl_used(mod):
     assert_result_dict_holds(result_dict)
 
 
-def run_and_verify_func(config, run_module, target="llvm", dtype="float32"):
+def run_and_verify_func(config, run_module, subgraph_num=None, target="llvm", dtype="float32"):
     """Test a Relay func by compiling, running, and comparing TVM and DNNL outputs.
     Parameters
     ----------
@@ -171,7 +175,9 @@ def run_and_verify_func(config, run_module, target="llvm", dtype="float32"):
         for k, v in input_shapes.items()
         if k not in is_param
     }
-    run_and_verify(f, input_dict, params, target=target, run_module=run_module)
+    run_and_verify(
+        f, input_dict, params, subgraph_num=subgraph_num, target=target, run_module=run_module
+    )
 
 
 def get_conv1d(
@@ -574,7 +580,6 @@ def get_graph(op, x_shape=(1, 8, 3, 3)):
         relay.log,
         relay.sqrt,
         relay.round,
-        relay.logsumexp,
         relay.nn.relu,
         relay.tanh,
         relay.sigmoid,
@@ -935,6 +940,38 @@ def get_graph(
     run_and_verify_func(get_graph(relay.nn.max_pool3d, strides=(1, 1, 1)), run_module=run_module)
 
 
+def test_prune_dnnl_subgraph(run_module):
+    """In this test, the "add" op should not be offloaded to the dnnl codegen."""
+
+    def get_graph():
+        x1 = relay.var("x1", shape=(1, 64, 56, 56))
+        x2 = relay.var("x2", shape=(1, 64, 56, 56))
+        bias = relay.var("bias", shape=(64,))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(
+            x1,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+        )
+        y = relay.nn.bias_add(y, bias)
+        y = relay.nn.relu(y)
+        y = relay.nn.global_max_pool2d(y)
+        y = relay.add(y, x2)
+        dic = {
+            "x1": (1, 64, 56, 56),
+            "x2": (1, 64, 56, 56),
+            "weight": (64, 64, 3, 3),
+            "bias": (64,),
+        }
+        param_lst = ["weight", "bias"]
+        out = tvm.IRModule.from_expr(y)
+        return out, dic, param_lst
+
+    run_and_verify_func(get_graph(), subgraph_num=1, run_module=run_module)
+
+
 if __name__ == "__main__":
     import sys
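For reference, a minimal sketch (not part of the patch) of how prune_dnnl_subgraphs is meant to be driven, mirroring the partition_for_dnnl helper in tests/python/contrib/test_dnnl.py; `mod` is a placeholder IRModule, and the pass list is a simplified assumption based on the usual BYOC sequence rather than the full set of passes used in the test:

    # Hypothetical driver sketch; `mod` is an existing relay IRModule (placeholder).
    import tvm
    from tvm.relay import transform
    from tvm.relay.op.contrib import dnnl

    seq = tvm.transform.Sequential(
        [
            transform.MergeComposite(dnnl.pattern_table()),
            transform.AnnotateTarget("dnnl"),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
        ]
    )
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
    # New in this patch: revert dnnl partitions with no conv/dense ops back to TVM.
    mod = dnnl.prune_dnnl_subgraphs(mod)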