From 9ae192909e7d87e059685df685f17276225a4091 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Sun, 6 Aug 2023 16:39:51 +0800 Subject: [PATCH 1/8] add test for set expr name --- CMakeLists.txt | 2 + cmake/config.cmake | 3 + cmake/modules/contrib/MSC.cmake | 26 + python/tvm/contrib/msc/__init__.py | 17 + python/tvm/contrib/msc/core/__init__.py | 17 + python/tvm/contrib/msc/core/_ffi_api.py | 21 + .../contrib/msc/core/transform/__init__.py | 20 + .../tvm/contrib/msc/core/transform/pattern.py | 490 +++++++ .../contrib/msc/core/transform/transform.py | 61 + python/tvm/contrib/msc/core/utils/__init__.py | 19 + python/tvm/contrib/msc/core/utils/expr.py | 105 ++ .../msc/core/transform/layout_utils.cc | 190 +++ src/contrib/msc/core/transform/layout_utils.h | 110 ++ .../msc/core/transform/set_expr_layout.cc | 1213 +++++++++++++++++ .../msc/core/transform/set_expr_name.cc | 348 +++++ src/contrib/msc/core/utils.cc | 314 +++++ src/contrib/msc/core/utils.h | 270 ++++ .../test_msc/test_transform_set_expr_name.py | 105 ++ 18 files changed, 3331 insertions(+) create mode 100644 cmake/modules/contrib/MSC.cmake create mode 100644 python/tvm/contrib/msc/__init__.py create mode 100644 python/tvm/contrib/msc/core/__init__.py create mode 100644 python/tvm/contrib/msc/core/_ffi_api.py create mode 100644 python/tvm/contrib/msc/core/transform/__init__.py create mode 100644 python/tvm/contrib/msc/core/transform/pattern.py create mode 100644 python/tvm/contrib/msc/core/transform/transform.py create mode 100644 python/tvm/contrib/msc/core/utils/__init__.py create mode 100644 python/tvm/contrib/msc/core/utils/expr.py create mode 100644 src/contrib/msc/core/transform/layout_utils.cc create mode 100644 src/contrib/msc/core/transform/layout_utils.h create mode 100644 src/contrib/msc/core/transform/set_expr_layout.cc create mode 100644 src/contrib/msc/core/transform/set_expr_name.cc create mode 100644 src/contrib/msc/core/utils.cc create mode 100644 src/contrib/msc/core/utils.h create mode 100644 tests/python/contrib/test_msc/test_transform_set_expr_name.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 47d57d56bd73..f7c34fa22bf7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -121,6 +121,7 @@ tvm_option(USE_CLML "Build with CLML Codegen support" OFF) tvm_option(USE_CLML_GRAPH_EXECUTOR "Build with CLML graph runtime" OFF) tvm_option(USE_UMA "Build with UMA support" OFF) tvm_option(USE_VERILATOR "Build with Verilator support" OFF) +tvm_option(USE_MSC "Enable Multi-System Compiler" OFF) # include directories include_directories(${CMAKE_INCLUDE_PATH}) @@ -545,6 +546,7 @@ include(cmake/modules/contrib/TensorRT.cmake) include(cmake/modules/contrib/VitisAI.cmake) include(cmake/modules/contrib/Verilator.cmake) include(cmake/modules/contrib/UMA.cmake) +include(cmake/modules/contrib/MSC.cmake) include(cmake/modules/Git.cmake) include(cmake/modules/LibInfo.cmake) include(cmake/modules/RustExt.cmake) diff --git a/cmake/config.cmake b/cmake/config.cmake index 8a7a0f1fdd29..4990e52d634f 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -281,6 +281,9 @@ set(USE_VITIS_AI OFF) # Build Verilator codegen and runtime set(USE_VERILATOR OFF) +# Whether to use the Multi-System Compiler +set(USE_MSC OFF) + #Whether to use CLML codegen set(USE_CLML OFF) # USE_CLML_GRAPH_EXECUTOR - CLML SDK PATH or ON or OFF diff --git a/cmake/modules/contrib/MSC.cmake b/cmake/modules/contrib/MSC.cmake new file mode 100644 index 000000000000..45ce776a0864 --- /dev/null +++ b/cmake/modules/contrib/MSC.cmake @@ -0,0 +1,26 @@ +# Licensed to the 
Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(USE_MSC) + tvm_file_glob(GLOB_RECURSE MSC_CORE_SOURCE "src/contrib/msc/*.cc") + list(APPEND COMPILER_SRCS ${MSC_CORE_SOURCE}) + + tvm_file_glob(GLOB_RECURSE MSC_RUNTIME_SOURCE "src/runtime/contrib/msc/*.cc") + list(APPEND RUNTIME_SRCS ${MSC_RUNTIME_SOURCE}) + + message(STATUS "Build with MSC support...") +endif() diff --git a/python/tvm/contrib/msc/__init__.py b/python/tvm/contrib/msc/__init__.py new file mode 100644 index 000000000000..a2813b4a2dca --- /dev/null +++ b/python/tvm/contrib/msc/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc""" diff --git a/python/tvm/contrib/msc/core/__init__.py b/python/tvm/contrib/msc/core/__init__.py new file mode 100644 index 000000000000..6d1a7c68c86d --- /dev/null +++ b/python/tvm/contrib/msc/core/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core""" diff --git a/python/tvm/contrib/msc/core/_ffi_api.py b/python/tvm/contrib/msc/core/_ffi_api.py new file mode 100644 index 000000000000..c0b0e21267ea --- /dev/null +++ b/python/tvm/contrib/msc/core/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core._ffi_api""" + +import tvm._ffi + +tvm._ffi._init_api("msc.core", __name__) diff --git a/python/tvm/contrib/msc/core/transform/__init__.py b/python/tvm/contrib/msc/core/transform/__init__.py new file mode 100644 index 000000000000..ec7459780359 --- /dev/null +++ b/python/tvm/contrib/msc/core/transform/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.transform""" + +from .pattern import * +from .transform import * diff --git a/python/tvm/contrib/msc/core/transform/pattern.py b/python/tvm/contrib/msc/core/transform/pattern.py new file mode 100644 index 000000000000..500870509791 --- /dev/null +++ b/python/tvm/contrib/msc/core/transform/pattern.py @@ -0,0 +1,490 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""tvm.contrib.msc.core.transform.pattern""" + +import tvm +from tvm.relax.dpl import pattern as relax_pattern +from tvm.relay import dataflow_pattern as relay_pattern + +from tvm.relax.transform import PatternCheckContext +from tvm.relax.backend.pattern_registry import register_patterns +from tvm.relay.op.contrib.register import register_pattern_table + + +def make_relax_conv_bias_pattern(op_name): + """A simple utility to create patterns for an operation fused with bias. 
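+
+    The body below matches ``relax.add(<op_name>(data, weight), relax.reshape(bias, shape))``,
+    i.e. a convolution with a constant weight whose constant bias is reshaped before the add.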
+ + Parameters + ---------- + op_name: str + The name of a Relax op, such as "relax.nn.conv2d" + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a conv_bias operation + """ + + data = relax_pattern.wildcard() + weight = relax_pattern.is_const() + conv = relax_pattern.is_op(op_name)(data, weight) + bias = relax_pattern.is_const() + shape = relax_pattern.wildcard() + reshape = relax_pattern.is_op("relax.reshape")(bias, shape) + out = relax_pattern.is_op("relax.add")(conv, reshape) + annotations = {"bias": bias, "reshape": reshape} + return out, annotations + + +def _check_relax_conv_bias(context: PatternCheckContext) -> bool: + """Check if conv_bias fuse pattern is correct.""" + bias = context.annotated_expr["bias"] + reshape = context.annotated_expr["reshape"] + non_one_dims = len([i for i in reshape.struct_info.shape.values if i > 1]) + return non_one_dims <= 1 and bias.struct_info.ndim == 1 + + +def make_relax_linear_pattern(): + """A simple utility to create patterns for linear. + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a linear operation + """ + + data = relax_pattern.wildcard() + weight = relax_pattern.is_const() + permute = relax_pattern.is_op("relax.permute_dims")(weight) + out = relax_pattern.is_op("relax.matmul")(data, permute) + annotations = {"weight": weight, "permute": permute} + return out, annotations + + +def _check_relax_linear(context: PatternCheckContext) -> bool: + """Check if linear pattern is correct.""" + weight = context.annotated_expr["weight"] + permute = context.annotated_expr["permute"] + return weight.struct_info.ndim == 2 and not permute.attrs["axes"] + + +def make_relax_linear_bias_pattern(): + """A simple utility to create patterns for linear with bias. + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a linear_bias operation + """ + + linear, annotations = make_relax_linear_pattern() + bias = relax_pattern.is_const() + out = relax_pattern.is_op("relax.add")(linear, bias) + annotations.update({"bias": bias, "out": out}) + return out, annotations + + +def _check_relax_linear_bias(context: PatternCheckContext) -> bool: + """Check if linear_bias pattern is correct.""" + if not _check_relax_linear(context): + return False + bias = context.annotated_expr["bias"] + return bias.struct_info.ndim == 1 + + +def make_relax_embedding_pattern(): + """A simple utility to create patterns for embedding. + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a embedding operation + """ + + weight = relax_pattern.is_const() + data = relax_pattern.wildcard() + astype = relax_pattern.is_op("relax.astype")(data) + out = relax_pattern.is_op("relax.take")(weight, astype) + annotations = {"weight": weight, "astype": astype} + return out, annotations + + +def _check_relax_embedding(context: PatternCheckContext) -> bool: + """Check if 1d embedding pattern is correct.""" + weight = context.annotated_expr["weight"] + astype = context.annotated_expr["astype"] + return ( + astype.attrs["dtype"] == "int32" + and weight.struct_info.ndim == 2 + and weight.struct_info.dtype == "float32" + ) + + +def make_relax_reshape_embedding_pattern(): + """A simple utility to create patterns for reshaped embedding. 
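+
+    The body below matches
+    ``relax.reshape(relax.take(weight, relax.reshape(astype(data), reduce_shape)), expand_shape)``,
+    i.e. an embedding lookup whose indices are flattened for ``relax.take`` and whose result is
+    reshaped back afterwards.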
+
+    Returns
+    -------
+    pattern: DFPattern
+        The resulting pattern describing a reshaped embedding operation
+    """
+
+    weight = relax_pattern.is_const()
+    data = relax_pattern.wildcard()
+    astype = relax_pattern.is_op("relax.astype")(data)
+    reduce_shape = relax_pattern.wildcard()
+    reduce_in = relax_pattern.is_op("relax.reshape")(astype, reduce_shape)
+    take = relax_pattern.is_op("relax.take")(weight, reduce_in)
+    expand_shape = relax_pattern.wildcard()
+    out = relax_pattern.is_op("relax.reshape")(take, expand_shape)
+    annotations = {"weight": weight, "astype": astype, "reduce_in": reduce_in}
+    return out, annotations
+
+
+def _check_relax_reshape_embedding(context: PatternCheckContext) -> bool:
+    """Check if reshape embedding pattern is correct."""
+    weight = context.annotated_expr["weight"]
+    if weight.struct_info.ndim != 2 or weight.struct_info.dtype != "float32":
+        return False
+    astype = context.annotated_expr["astype"]
+    reduce_in = context.annotated_expr["reduce_in"]
+    if astype.attrs["dtype"] != "int32" or reduce_in.struct_info.ndim != 1:
+        return False
+    return True
+
+
+def make_relax_attention_pattern():
+    """A simple utility to create patterns for attention.
+
+    Returns
+    -------
+    pattern: DFPattern
+        The resulting pattern describing an attention operation
+    """
+
+    weight_q = relax_pattern.wildcard()
+    weight_k = relax_pattern.wildcard()
+    weight_v = relax_pattern.wildcard()
+    q_trans = relax_pattern.is_op("relax.permute_dims")(weight_q)
+    k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k)
+    v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v)
+    out = relax_pattern.is_op("relax.nn.attention")(q_trans, k_trans, v_trans)
+    annotations = {"q_trans": q_trans, "k_trans": k_trans, "v_trans": v_trans}
+    return out, annotations
+
+
+def _check_relax_attention(context: PatternCheckContext) -> bool:
+    """Check if attention pattern is correct."""
+    return True
+
+
+def make_relax_mask_attention_pattern():
+    """A simple utility to create patterns for mask_attention.
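+
+    The body below matches ``relax.nn.attention_bias`` applied to the ``permute_dims`` of the
+    query/key/value inputs plus an explicit mask input, i.e. the attention pattern above
+    extended with a mask.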
+ + Returns + ------- + pattern: DFPattern + The resulting pattern describing a mask_attention operation + """ + + weight_q = relax_pattern.wildcard() + weight_k = relax_pattern.wildcard() + weight_v = relax_pattern.wildcard() + mask = relax_pattern.wildcard() + q_trans = relax_pattern.is_op("relax.permute_dims")(weight_q) + k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k) + v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v) + out = relax_pattern.is_op("relax.nn.attention_bias")(q_trans, k_trans, v_trans, mask) + annotations = {"q_trans": q_trans, "k_trans": k_trans, "v_trans": v_trans} + return out, annotations + + +def _check_relax_mask_attention(context: PatternCheckContext) -> bool: + """Check if mask_attention pattern is correct.""" + return True + + +# TODO(tong.meng): support patterns after optimize +register_patterns( + [ + ( + "msc.conv1d_bias", + *make_relax_conv_bias_pattern( + "relax.nn.conv1d", + ), + _check_relax_conv_bias, + ), + ( + "msc.conv2d_bias", + *make_relax_conv_bias_pattern( + "relax.nn.conv2d", + ), + _check_relax_conv_bias, + ), + ( + "msc.linear", + *make_relax_linear_pattern(), + _check_relax_linear, + ), + ( + "msc.linear_bias", + *make_relax_linear_bias_pattern(), + _check_relax_linear_bias, + ), + ( + "msc.embedding", + *make_relax_embedding_pattern(), + _check_relax_embedding, + ), + ( + "msc.embedding", + *make_relax_reshape_embedding_pattern(), + _check_relax_reshape_embedding, + ), + ( + "msc.attention", + *make_relax_attention_pattern(), + _check_relax_attention, + ), + ( + "msc.attention", + *make_relax_mask_attention_pattern(), + _check_relax_mask_attention, + ), + ] +) + + +# TODO(tong.meng): support patterns after optimize +@register_pattern_table("msc") +def pattern_table(): + """Returns list of triples describing the name, dataflow pattern and predicate for all + the MSC-supported operators.""" + + def make_relay_conv_bias_pattern(op_name, optimized=False): + """A simple utility to create patterns for an operation fused with bias. + + Parameters + ---------- + op_name: str + The name of a Relay op, such as "relay.nn.conv2d" + optimized: bool + Whether the relay is optimized + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a conv_bias operation + """ + + data = relay_pattern.wildcard() + weight = relay_pattern.is_constant() + bias = relay_pattern.is_constant() + conv = relay_pattern.is_op(op_name)(data, weight) + if optimized: + out = relay_pattern.is_op("add")(conv, bias) + else: + out = relay_pattern.is_op("nn.bias_add")(conv, bias) + return out + + def _check_relay_conv_bias(call: tvm.relay.Expr) -> bool: + """Check if conv_bias fuse pattern is correct.""" + + if call.op.name == "nn.bias_add": + bias = call.args[1] + return len(bias.checked_type.shape) == 1 + if call.op.name == "add": + return True + return False + + def make_relay_linear_pattern(optimized=False): + """A simple utility to create patterns for linear. 
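+
+        Both branches below match the decomposed relay form of a dense layer:
+        ``squeeze(reshape(batch_matmul(reshape(broadcast_to(data)), <weight>)))``, where in the
+        non-optimized branch ``<weight>`` is the transposed, broadcast and reshaped constant
+        weight, and in the optimized branch the constant weight feeds ``nn.batch_matmul``
+        directly.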
+ + Parameters + ---------- + optimized: bool + Whether the relay is optimized + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a linear operation + """ + + if optimized: + data = relay_pattern.wildcard() + weight = relay_pattern.is_constant() + broadcast_data = relay_pattern.is_op("broadcast_to")(data) + reshape_data = relay_pattern.is_op("reshape")(broadcast_data) + batch_matmul = relay_pattern.is_op("nn.batch_matmul")(reshape_data, weight) + reshape_out = relay_pattern.is_op("reshape")(batch_matmul) + return relay_pattern.is_op("squeeze")(reshape_out) + data = relay_pattern.wildcard() + weight = relay_pattern.is_constant() + trans_weight = relay_pattern.is_op("transpose")(weight) + broadcast_data = relay_pattern.is_op("broadcast_to")(data) + broadcast_weight = relay_pattern.is_op("broadcast_to")(trans_weight) + reshape_data = relay_pattern.is_op("reshape")(broadcast_data) + reshape_weight = relay_pattern.is_op("reshape")(broadcast_weight) + batch_matmul = relay_pattern.is_op("nn.batch_matmul")(reshape_data, reshape_weight) + reshape_out = relay_pattern.is_op("reshape")(batch_matmul) + return relay_pattern.is_op("squeeze")(reshape_out) + + def _check_relay_linear(call: tvm.relay.Expr) -> bool: + """Check if linear pattern is correct.""" + return True + + def make_relay_linear_bias_pattern(optimized=False): + """A simple utility to create patterns for linear_bias. + + Parameters + ---------- + optimized: bool + Whether the relay is optimized + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a linear_bias operation + """ + + bias = relay_pattern.is_constant() + linear = make_relay_linear_pattern(optimized) + if optimized: + out = relay_pattern.is_op("add")(linear, bias) + else: + out = relay_pattern.is_op("nn.bias_add")(linear, bias) + return out + + def _check_relay_linear_bias(call: tvm.relay.Expr) -> bool: + """Check if linear_bias pattern is correct.""" + return True + + def make_relay_matmul_pattern(dim=2, optimized=False): + """A simple utility to create patterns for matmul. 
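+
+        With ``dim=2`` the body below matches ``nn.dense(a, transpose(b))``, optionally followed
+        by ``squeeze``; with ``dim=3`` it matches the broadcast/reshape/``nn.batch_matmul``
+        decomposition ending in ``squeeze``.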
+
+        Parameters
+        ----------
+        optimized: bool
+            Whether the relay is optimized
+
+        Returns
+        -------
+        pattern: DFPattern
+            The resulting pattern describing a matmul operation
+        """
+
+        if dim == 2:
+            a = relay_pattern.wildcard()
+            b = relay_pattern.wildcard()
+            trans_b = relay_pattern.is_op("transpose")(b)
+            dense = relay_pattern.is_op("nn.dense")(a, trans_b)
+            return dense | relay_pattern.is_op("squeeze")(dense)
+        elif dim == 3:
+            a = relay_pattern.wildcard()
+            b = relay_pattern.wildcard()
+            broadcast_a = relay_pattern.is_op("broadcast_to")(a)
+            broadcast_b = relay_pattern.is_op("broadcast_to")(b)
+            reshape_a = relay_pattern.is_op("reshape")(broadcast_a)
+            reshape_b = relay_pattern.is_op("reshape")(broadcast_b)
+            batch_matmul = relay_pattern.is_op("nn.batch_matmul")(reshape_a, reshape_b)
+            reshape_out = relay_pattern.is_op("reshape")(batch_matmul)
+            return relay_pattern.is_op("squeeze")(reshape_out)
+        else:
+            raise Exception("matmul pattern only supports dim 2 and 3")
+
+    def _check_relay_matmul(call: tvm.relay.Expr) -> bool:
+        """Check if matmul pattern is correct."""
+        last_call = call.args[0] if call.op.name == "squeeze" else call
+        if last_call.op.name == "nn.dense":
+            trans_b = last_call.args[1]
+            b = trans_b.args[0]
+            if len(b.checked_type.shape) != 2:
+                return False
+            return trans_b.attrs["axes"] is None or list(trans_b.attrs["axes"]) == [1, 0]
+        return True
+
+    def make_relay_embedding_pattern(optimized=False):
+        """A simple utility to create patterns for 1d embedding.
+
+        Returns
+        -------
+        pattern: DFPattern
+            The resulting pattern describing an embedding operation
+        """
+
+        weight = relay_pattern.is_constant()
+        data = relay_pattern.wildcard()
+        astype = relay_pattern.is_op("cast")(data)
+        return relay_pattern.is_op("take")(weight, astype)
+
+    def _check_relay_embedding(call) -> bool:
+        """Check if embedding pattern is correct."""
+
+        weight = call.args[0]
+        cast = call.args[1]
+        return (
+            cast.attrs["dtype"] == "int32"
+            and len(weight.checked_type.shape) == 2
+            and weight.checked_type.dtype == "float32"
+        )
+
+    def make_relay_gelu_pattern(optimized=False):
+        """A simple utility to create patterns for gelu.
+
+        Returns
+        -------
+        pattern: DFPattern
+            The resulting pattern describing a gelu operation.
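+            (the body below builds ``multiply(data, add(c3, multiply(erf(multiply(data, c1)), c2)))``,
+            i.e. the erf form ``x * (c3 + c2 * erf(c1 * x))`` of gelu)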
+ """ + + data = relay_pattern.wildcard() + factor_1 = relay_pattern.is_constant() + mul_1 = relay_pattern.is_op("multiply")(data, factor_1) + erf = relay_pattern.is_op("erf")(mul_1) + factor_2 = relay_pattern.is_constant() + mul_2 = relay_pattern.is_op("multiply")(erf, factor_2) + factor_3 = relay_pattern.is_constant() + add = relay_pattern.is_op("add")(factor_3, mul_2) + return relay_pattern.is_op("multiply")(data, add) + + def _check_relay_gelu(call) -> bool: + """Check if gelu pattern is correct.""" + return True + + return [ + ("msc.conv1d_bias", make_relay_conv_bias_pattern("nn.conv1d"), _check_relay_conv_bias), + ( + "msc.conv1d_bias", + make_relay_conv_bias_pattern("nn.conv1d", True), + _check_relay_conv_bias, + ), + ("msc.conv2d_bias", make_relay_conv_bias_pattern("nn.conv2d"), _check_relay_conv_bias), + ( + "msc.conv2d_bias", + make_relay_conv_bias_pattern("nn.conv2d", True), + _check_relay_conv_bias, + ), + ("msc.linear_bias", make_relay_linear_bias_pattern(), _check_relay_linear_bias), + ("msc.linear", make_relay_linear_pattern(), _check_relay_linear), + ("msc.linear", make_relay_linear_pattern(True), _check_relay_linear), + ("msc.matmul", make_relay_matmul_pattern(dim=2), _check_relay_matmul), + ("msc.matmul", make_relay_matmul_pattern(dim=3), _check_relay_matmul), + ("msc.embedding", make_relay_embedding_pattern(), _check_relay_embedding), + ("msc.gelu", make_relay_gelu_pattern(), _check_relay_gelu), + ] diff --git a/python/tvm/contrib/msc/core/transform/transform.py b/python/tvm/contrib/msc/core/transform/transform.py new file mode 100644 index 000000000000..355922d6def2 --- /dev/null +++ b/python/tvm/contrib/msc/core/transform/transform.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""tvm.contrib.msc.core.transform.transform""" + +import tvm +from tvm.relax.transform import _ffi_api as relax_api +from tvm.relay.transform import _ffi_api as relay_api + + +def SetExprName(as_relax=True, entry_name="main") -> tvm.ir.transform.Pass: + """Set name for the call and constant in IRModule. + + Parameters + ---------- + as_relax: bool + Whether set names for relax, otherwise for relay. + entry_name: str + The entry name + + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + + if as_relax: + return relax_api.SetRelaxExprName(entry_name) # type: ignore + return relay_api.SetRelayExprName(entry_name) # type: ignore + + +def SetExprLayout(allow_missing=True, entry_name="main") -> tvm.ir.transform.Pass: + """Set layout for the var and constant in IRModule. + + Parameters + ---------- + allow_missing: bool + Whether allow missing layouts. 
+ entry_name: str + The entry name + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + + return relax_api.SetExprLayout(allow_missing, entry_name) # type: ignore diff --git a/python/tvm/contrib/msc/core/utils/__init__.py b/python/tvm/contrib/msc/core/utils/__init__.py new file mode 100644 index 000000000000..65f9e1b32624 --- /dev/null +++ b/python/tvm/contrib/msc/core/utils/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.utils""" + +from .expr import * diff --git a/python/tvm/contrib/msc/core/utils/expr.py b/python/tvm/contrib/msc/core/utils/expr.py new file mode 100644 index 000000000000..ad459e78325d --- /dev/null +++ b/python/tvm/contrib/msc/core/utils/expr.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.utils.expr""" + +import tvm +from tvm import relax +from tvm.relax import PyExprVisitor +from tvm.contrib.msc.core import _ffi_api + + +def get_span_attrs(mod: tvm.IRModule) -> dict: + """Extract the span attributes from relax.Function. + + Parameters + ---------- + mod: IRModule + The IRModule of relax. 
+ + Returns + ------- + attrs: dict + """ + + @relax.expr_functor.visitor + class SpanVisitor(PyExprVisitor): + """Visitor for get attributes in span""" + + def extract(self, expr: relax.Expr) -> dict: + self._span_info = {} + if isinstance(expr, relax.Expr): + self.visit_expr(expr) + elif isinstance(expr, relax.BindingBlock): + self.visit_binding_block(expr) + return self._span_info + + def _update_attrs(self, expr: relax.Expr, name: str = "") -> None: + if not expr.span: + return + name = name or _ffi_api.SpanGetAttr(expr.span, "name") + if not name: + return + self._span_info[name] = _ffi_api.SpanGetAttrs(expr.span) + + def visit_var_binding_(self, binding: relax.VarBinding) -> None: + super().visit_var_binding_(binding) + self._update_attrs(binding.value, binding.var.name_hint) + + def visit_constant_(self, op: relax.Constant) -> None: + super().visit_constant_(op) + self._update_attrs(op) + + def visit_var_(self, op: relax.Var) -> None: + super().visit_var_(op) + self._update_attrs(op, op.name_hint) + + return {v.name_hint: SpanVisitor().extract(mod[v]) for v in mod.functions} + + +def msc_script(mod: tvm.IRModule, script: str = "") -> str: + """Add span attrs after lines. + + Parameters + ---------- + mod: IRModule + The IRModule of relax. + script: string + The script to be replaced + + Returns + ------- + script: string + The replaced script + """ + + script = script or str(mod) + attrs = get_span_attrs(mod) + cur_attr, lines = {}, [] + for line in script.split("\n"): + if line.strip().startswith("def "): + func_name = line.strip().split("def ")[1].split("(")[0] + cur_attr = attrs.get(func_name, {}) + if ": " in line: + v_name = line.strip().split(": ")[0] + if v_name in cur_attr: + line += ( + " # " + + ", ".join(["{}={}".format(k, v) for k, v in cur_attr[v_name].items()]) + + " #" + ) + lines.append(line) + return "\n".join(lines) diff --git a/src/contrib/msc/core/transform/layout_utils.cc b/src/contrib/msc/core/transform/layout_utils.cc new file mode 100644 index 000000000000..ffc631c6d033 --- /dev/null +++ b/src/contrib/msc/core/transform/layout_utils.cc @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/contrib/msc/core/transform/layout_utils.cc + */ +#include "layout_utils.h" + +#include +#include + +namespace tvm { +namespace contrib { +namespace msc { + +bool LayoutUtils::LayoutInfered(const Expr& expr) { + const String& layout = SpanUtils::GetAttr(expr->span, "layout"); + return layout.size() > 0; +} + +bool LayoutUtils::SetLayout(const Expr& expr, const NLayout& layout) { + const String& saved_layout = SpanUtils::GetAttr(expr->span, "layout"); + const auto& sinfo = GetStructInfo(expr); + if (sinfo.as() || sinfo.as()) { + ICHECK(layout.IsLeaf()) << "Expr has tensor struct, but find nested layout " << expr; + const auto& l_layout = layout.LeafValue()->layout; + if (!l_layout.defined()) { + return false; + } + if (saved_layout == l_layout.name()) { + return false; + } + expr->span = SpanUtils::SetAttr(expr->span, "layout", l_layout.name()); + } else if (sinfo.as()) { + ICHECK(!layout.IsLeaf()) << "Expr has tupple struct, but find non-nested layout " << expr; + String layout_str; + Array nested_layouts = layout.NestedArray(); + for (size_t i = 0; i < nested_layouts.size(); i++) { + ICHECK(nested_layouts[i].IsLeaf()) + << "Expr input[" << i << "] has tensor struct, but find nested layout " << expr; + const auto& l_layout = nested_layouts[i].LeafValue()->layout; + if (!l_layout.defined()) { + return false; + } + layout_str = layout_str + l_layout.name() + (i < nested_layouts.size() - 1 ? "," : ""); + } + if (saved_layout == layout_str) { + return false; + } + expr->span = SpanUtils::SetAttr(expr->span, "layout", layout_str); + } + return true; +} + +const NLayout LayoutUtils::GetNLayout(const Expr& expr) { + if (!LayoutInfered(expr)) { + return LayoutDecision(""); + } + auto sinfo = GetStructInfo(expr); + if (sinfo.as()) { + return LayoutDecision(SpanUtils::GetAttr(expr->span, "layout")); + } + if (sinfo.as()) { + String layout_str = SpanUtils::GetAttr(expr->span, "layout"); + std::vector output_layout; + for (const auto& l : StringUtils::Split(layout_str, ",")) { + output_layout.push_back(LayoutDecision(l)); + } + return NLayout(output_layout); + } + return LayoutDecision(""); +} + +const LayoutDecision LayoutUtils::GetLayoutDecision(const Expr& expr) { + NLayout nlayout = GetNLayout(expr); + ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << expr; + return nlayout.LeafValue(); +} + +bool LayoutUtils::HasUnknownDimTensor(const NLayout& nlayout) { + bool find = false; + auto fvisit = [&](const LayoutDecision& layout) { + find = find | (NLayoutEqual()(layout, LayoutDecision::InitUnknownDim())); + }; + ForEachLeaf(nlayout, fvisit); + return find; +} + +bool LayoutUtils::HasUnknownDimTensor(const Array& args) { + for (const auto& arg : args) { + if (IsNestedTensor(arg)) { + if (HasUnknownDimTensor(GetNLayout(arg))) { + return true; + } + } + } + return false; +} + +const LayoutDecision LayoutUtils::ExpandLayout(const LayoutDecision& src_layout, + const std::vector& expand_axes) { + if (!src_layout->layout.defined()) { + return src_layout; + } + std::string new_layout = src_layout.name(); + ICHECK_EQ(new_layout.size(), src_layout->layout.ndim()) + << "Only support normal layout, get " << src_layout->layout; + std::vector priority_dims{"N", "C", "H", "W", "D", "G", "T"}; + size_t left_size = expand_axes.size(); + for (const auto& a : expand_axes) { + std::string target = "U"; + if (new_layout.find("H") && !new_layout.find("W")) { + target = "W"; + } else if (new_layout.find("W") && !new_layout.find("H")) { + target = "H"; + } else if (left_size == 1 && new_layout.find("C") && 
!new_layout.find("D")) { + target = "D"; + } else if (left_size == 1 && new_layout.find("D") && !new_layout.find("C")) { + target = "C"; + } else { + for (const auto& p : priority_dims) { + int pos = new_layout.find(p); + if (pos < 0) { + target = p; + break; + } + } + } + new_layout = new_layout.insert(a, target); + left_size--; + } + return LayoutDecision(new_layout); +} + +const LayoutDecision LayoutUtils::ReduceLayout(const LayoutDecision& src_layout, + const std::vector& reduce_axes) { + if (!src_layout->layout.defined()) { + return src_layout; + } + std::set reduce_axes_set; + for (const auto& a : reduce_axes) { + reduce_axes_set.insert(a); + } + std::string new_layout = ""; + for (size_t i = 0; i < src_layout->layout.ndim(); i++) { + if (reduce_axes_set.count(i)) { + continue; + } + new_layout += src_layout->layout[i].name(); + } + return LayoutDecision(new_layout); +} + +const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision& src_layout, + const Array& axes) { + String layout_str; + for (const auto& a : axes) { + layout_str = layout_str + src_layout->layout[a->value].name(); + } + return LayoutDecision(layout_str); +} + +const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision& src_layout, + const std::vector& axes) { + String layout_str; + for (const auto& a : axes) { + layout_str = layout_str + src_layout->layout[a].name(); + } + return LayoutDecision(layout_str); +} + +} // namespace msc +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/msc/core/transform/layout_utils.h b/src/contrib/msc/core/transform/layout_utils.h new file mode 100644 index 000000000000..b9de832838c5 --- /dev/null +++ b/src/contrib/msc/core/transform/layout_utils.h @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/core/transform/layout_utils.h + * \brief Common utilities for layout. + */ +#ifndef TVM_CONTRIB_MSC_CORE_TRANSFORM_LAYOUT_UTILS_H_ +#define TVM_CONTRIB_MSC_CORE_TRANSFORM_LAYOUT_UTILS_H_ + +#include +#include + +#include + +#include "../../../../relax/transform/infer_layout_utils.h" +#include "../../../../relax/transform/utils.h" +#include "../utils.h" + +namespace tvm { +namespace contrib { +namespace msc { + +using Expr = tvm::RelayExpr; +using namespace tvm::relax; + +/*! + * \brief Utils for Layout. + */ +class LayoutUtils { + public: + /*! + * \brief Check if the layout is infered. + * \return Whether the layout is infered. + */ + TVM_DLL static bool LayoutInfered(const Expr& expr); + + /*! + * \brief Set the layout to span + * \return Whether the layout is setted. + */ + TVM_DLL static bool SetLayout(const Expr& expr, const NLayout& layout); + + /*! + * \brief Get the layout from span + * \return The NLayout. 
+ */ + TVM_DLL static const NLayout GetNLayout(const Expr& expr); + + /*! + * \brief Get the layout desion from span + * \return The LayoutDecision. + */ + TVM_DLL static const LayoutDecision GetLayoutDecision(const Expr& expr); + + /*! + * \brief Check if the layout has unknown dim tensor. + * \return Whether the layout has unknown dim tensor. + */ + TVM_DLL static bool HasUnknownDimTensor(const NLayout& nlayout); + + /*! + * \brief Check if the args has unknown dim tensor. + * \return Whether the args has unknown dim tensor. + */ + TVM_DLL static bool HasUnknownDimTensor(const Array& args); + + /*! + * \brief Insert axes to the Layout + * \return The new layout. + */ + TVM_DLL static const LayoutDecision ExpandLayout(const LayoutDecision& src_layout, + const std::vector& expand_axes); + + /*! + * \brief Delete axes from the Layout + * \return The new layout. + */ + TVM_DLL static const LayoutDecision ReduceLayout(const LayoutDecision& src_layout, + const std::vector& reduce_axes); + /*! + * \brief Permute axes from the Layout + * \return The new layout. + */ + TVM_DLL static const LayoutDecision PermuteLayout(const LayoutDecision& src_layout, + const Array& axes); + TVM_DLL static const LayoutDecision PermuteLayout(const LayoutDecision& src_layout, + const std::vector& axes); +}; + +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_MSC_CORE_TRANSFORM_LAYOUT_UTILS_H_ diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc new file mode 100644 index 000000000000..981829b56809 --- /dev/null +++ b/src/contrib/msc/core/transform/set_expr_layout.cc @@ -0,0 +1,1213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/core/transform/set_expr_layout.cc + * \brief Pass for setting layout for expr and constant. 
+ */ + +#include +#include +#include + +#include "../utils.h" +#include "layout_utils.h" + +namespace tvm { +namespace relax { + +using namespace tvm::contrib::msc; + +NLayout InferNLayout(const Expr& expr, const VarLayoutMap& var_layout_map) { + if (expr.as() && var_layout_map.count(Downcast(expr))) { + return GetNLayout(var_layout_map, expr); + } + return LayoutUtils::GetNLayout(expr); +} + +LayoutDecision InferLayoutDecision(const Expr& expr, const VarLayoutMap& var_layout_map) { + const auto& nlayout = InferNLayout(expr, var_layout_map); + ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << expr; + return nlayout.LeafValue(); +} + +LayoutDecision InferLayoutDecisionAt(const Expr& expr, const VarLayoutMap& var_layout_map, + size_t index = 0) { + const auto& nlayouts = InferNLayout(expr, var_layout_map); + const auto& nlayout = nlayouts.NestedArray()[0]; + ICHECK(nlayout.IsLeaf()) << "Cannot get output layout for " << expr; + return nlayout.LeafValue(); +} + +std::tuple AccumulateMatch(const std::vector& in_shape, + const std::vector& out_shape, size_t in_start, + size_t out_start) { + // find input position in_pos and output position out_pos + // cumsum(in_shape[in_start:in_ops])==cumsum(out_shape[out_start:out_pos]) + int64_t in_pos = -1; + int64_t out_pos = -1; + int64_t in_accumulate = 1; + int64_t out_accumulate = 1; + for (size_t i = in_start; i < in_shape.size(); i++) { + in_accumulate *= in_shape[i]; + out_accumulate = 1; + for (size_t j = out_start; j < out_shape.size(); j++) { + out_accumulate *= out_shape[j]; + if (in_accumulate == out_accumulate) { + in_pos = i; + out_pos = j; + break; + } else if (out_accumulate > in_accumulate) { + break; + } + } + if (in_pos >= 0) { + break; + } + } + // append tailed 1s + if (in_pos >= 0) { + while (in_pos < in_shape.size() - 1 && in_shape[in_pos + 1] == 1) { + in_pos++; + } + while (out_pos < out_shape.size() - 1 && out_shape[out_pos + 1] == 1) { + out_pos++; + } + } + return std::make_tuple(in_pos, out_pos); +} + +std::vector InferReduceAxes(const Array& input_shape, + const Array& output_shape) { + std::vector reduce_axes, out_axes; + std::vector in_shape, out_shape; + for (const auto& s : input_shape) { + in_shape.push_back(Downcast(s)->value); + } + for (const auto& s : output_shape) { + out_shape.push_back(Downcast(s)->value); + } + size_t start = 0; + while (start < in_shape.size() && out_axes.size() < out_shape.size()) { + if (in_shape[start] == out_shape[out_axes.size()]) { + out_axes.push_back(start); + start++; + } else { + int64_t in_pos, out_pos; + size_t out_start = out_axes.size(); + std::tie(in_pos, out_pos) = AccumulateMatch(in_shape, out_shape, start, out_start); + if (in_pos == -1) { + return std::vector(); + } + for (size_t i = out_start; i < out_pos + 1; i++) { + out_axes.push_back(i + 1); + } + start = in_pos + 1; + } + } + if (out_axes.size() != out_shape.size()) { + return std::vector(); + } + std::set out_axes_set; + for (const auto& a : out_axes) { + out_axes_set.insert(a); + } + for (size_t i = 0; i < in_shape.size(); i++) { + if (!out_axes_set.count(i)) { + reduce_axes.push_back(i); + } + } + return reduce_axes; +} + +std::vector InferExpandAxes(const Array& input_shape, + const Array& output_shape) { + std::vector expand_axes; + std::vector in_shape, out_shape; + for (const auto& s : input_shape) { + in_shape.push_back(Downcast(s)->value); + } + for (const auto& s : output_shape) { + out_shape.push_back(Downcast(s)->value); + } + size_t start = 0; + while (start < in_shape.size() && expand_axes.size() + 
in_shape.size() < out_shape.size()) { + if (in_shape[start] == out_shape[start + expand_axes.size()]) { + start++; + } else { + int64_t in_pos, out_pos; + size_t out_start = start + expand_axes.size(); + std::tie(in_pos, out_pos) = AccumulateMatch(in_shape, out_shape, start, out_start); + if (in_pos == -1) { + return std::vector(); + } + size_t expand_size = out_pos - in_pos - expand_axes.size(); + for (size_t i = 0; i < expand_size; i++) { + expand_axes.push_back(out_start + i); + } + start = in_pos + 1; + } + } + if (expand_axes.size() + in_shape.size() != out_shape.size()) { + return std::vector(); + } + return expand_axes; +} + +// Forward and Backward infer +InferLayoutOutput MSCInferLayoutConv(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision data_layout, kernel_layout, out_layout; + const String& op_name = Downcast(call->op)->name; + if (op_name == "relax.nn.conv1d") { + const auto* attrs = call->attrs.as(); + data_layout = LayoutDecision(attrs->data_layout); + kernel_layout = LayoutDecision(attrs->kernel_layout); + out_layout = LayoutDecision(attrs->out_layout); + } else if (op_name == "relax.nn.conv2d") { + const auto* attrs = call->attrs.as(); + data_layout = LayoutDecision(attrs->data_layout); + kernel_layout = LayoutDecision(attrs->kernel_layout); + out_layout = LayoutDecision(attrs->out_layout); + } + return InferLayoutOutput({data_layout, kernel_layout}, {out_layout}, Attrs()); +} + +InferLayoutOutput MSCInferLayoutPool2d(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision layout, out_layout; + const String& op_name = Downcast(call->op)->name; + if (op_name == "relax.nn.adaptive_avg_pool2d") { + const auto* attrs = call->attrs.as(); + layout = LayoutDecision(attrs->layout); + out_layout = LayoutDecision(attrs->out_layout); + } else { + const auto* attrs = call->attrs.as(); + layout = LayoutDecision(attrs->layout); + out_layout = LayoutDecision(attrs->out_layout); + } + return InferLayoutOutput({layout}, {out_layout}, Attrs()); +} + +// Forward Infer +InferLayoutOutput ForwardInferLayoutCommon(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + Array input_layouts; + LayoutDecision layout_hint; + for (const auto& arg : call->args) { + const auto& in_layout = InferLayoutDecision(arg, var_layout_map); + if (in_layout->layout.defined()) { + layout_hint = in_layout; + } + input_layouts.push_back(in_layout); + } + if (!layout_hint.defined()) { + return InferLayoutOutput(); + } + std::vector output_layouts; + const auto& sinfo = GetStructInfo(call); + if (sinfo.as()) { + output_layouts.push_back(layout_hint); + } else if (const auto* tuple_sinfo = sinfo.as()) { + for (size_t i = 0; i < tuple_sinfo->fields.size(); i++) { + output_layouts.push_back(layout_hint); + } + } else { + return InferLayoutOutput(); + } + return InferLayoutOutput(input_layouts, {output_layouts}, Attrs()); +} + +InferLayoutOutput ForwardInferLayoutBinary(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + const auto& output = ForwardInferLayoutCommon(call, desired_layouts, var_layout_map); + if (!output.defined()) { + return output; + } + std::vector input_layouts; + for (size_t i = 0; i < call->args.size(); i++) { + const auto& sinfo = GetStructInfo(call->args[i]); + if (const auto* t_info = sinfo.as()) { + if (t_info->ndim == 0) { + input_layouts.push_back(LayoutDecision("")); + } else { + 
input_layouts.push_back(output->input_layouts[i]); + } + } else { + LOG(FATAL) << "Binary input should be tensor, get " << sinfo->GetTypeKey(); + } + } + return InferLayoutOutput(input_layouts, output->output_layouts, Attrs()); +} + +InferLayoutOutput ForwardInferLayoutInplace(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + return ForwardInferLayoutCommon(call, desired_layouts, var_layout_map); +} + +InferLayoutOutput ForwardInferLayoutBatchNorm(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + Array empty; + const auto& input_shape = + Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + if (input_shape.size() == 0) { + return InferLayoutOutput(); + } + LayoutDecision in_layout = InferLayoutDecision(call->args[0], var_layout_map); + if (!in_layout->layout.defined()) { + if (input_shape.size() == 4) { + in_layout = LayoutDecision("NCHW"); + } else if (input_shape.size() == 3) { + in_layout = LayoutDecision("NCD"); + } + } + LayoutDecision g_layout = LayoutDecision("O"); + return InferLayoutOutput({in_layout, g_layout, g_layout, g_layout, g_layout}, + {{in_layout, g_layout, g_layout}}, Attrs()); +} + +InferLayoutOutput ForkwardInferLayoutExpandDims(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision input_layout = InferLayoutDecision(call->args[0], var_layout_map); + if (!input_layout->layout.defined()) { + return InferLayoutOutput(); + } + Array empty; + const auto& input_shape = + Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + if (input_shape.size() == 0) { + return InferLayoutOutput(); + } + const auto* attrs = call->attrs.as(); + std::vector expand_axes; + for (const auto& s : attrs->axis) { + expand_axes.push_back(CommonUtils::GetIndex(s->value, input_shape.size())); + } + LayoutDecision output_layout = LayoutUtils::ExpandLayout(input_layout, expand_axes); + return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); +} + +InferLayoutOutput ForwardInferLayoutNormalize(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + Array empty; + const auto& input_shape = + Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + if (input_shape.size() == 0) { + return InferLayoutOutput(); + } + LayoutDecision in_layout = InferLayoutDecision(call->args[0], var_layout_map); + if (!in_layout->layout.defined()) { + if (input_shape.size() == 4) { + in_layout = LayoutDecision("NCHW"); + } else if (input_shape.size() == 3) { + in_layout = LayoutDecision("NCD"); + } + } + LayoutDecision g_layout = LayoutDecision("O"); + return InferLayoutOutput({in_layout, g_layout, g_layout}, {in_layout}, Attrs()); +} + +InferLayoutOutput ForwardInferLayoutMatmul(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + Array empty; + const auto& a_shape = + Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& b_shape = + Downcast(GetStructInfo(call->args[1]))->GetShape().value_or(empty); + + if (a_shape.size() == 0) { + return InferLayoutOutput(); + } + LayoutDecision a_layout = InferLayoutDecision(call->args[0], var_layout_map); + if (!a_layout->layout.defined()) { + if (a_shape.size() == 4) { + a_layout = LayoutDecision("NCHW"); + } else if (a_shape.size() == 3) { + a_layout = LayoutDecision("NCD"); + } else if (a_shape.size() == 2) { + a_layout = LayoutDecision("NC"); + } + } + size_t start = a_layout->layout.ndim() - 
b_shape.size(); + String pre_layout; + for (size_t i = start; i < a_layout->layout.ndim() - 2; i++) { + pre_layout = pre_layout + a_layout->layout[i].name(); + } + LayoutDecision b_layout = LayoutDecision(pre_layout + "IO"); + return InferLayoutOutput({a_layout, b_layout}, {a_layout}, Attrs()); +} + +InferLayoutOutput ForwardInferLayoutPermute(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision input_layout = InferLayoutDecision(call->args[0], var_layout_map); + if (!input_layout->layout.defined()) { + return InferLayoutOutput(); + } + std::vector permute_axes; + const auto* attrs = call->attrs.as(); + if (!attrs->axes.defined()) { + for (size_t i = input_layout->layout.ndim(); i > 0; i--) { + permute_axes.push_back(i - 1); + } + } else { + for (const auto& a : attrs->axes.value()) { + permute_axes.push_back(a->value); + } + } + LayoutDecision output_layout = LayoutUtils::PermuteLayout(input_layout, permute_axes); + return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); +} + +InferLayoutOutput ForwardInferLayoutReduceAxis(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision input_layout = InferLayoutDecision(call->args[0], var_layout_map); + if (!input_layout->layout.defined()) { + return InferLayoutOutput(); + } + const auto* attrs = call->attrs.as(); + if (attrs->keepdims) { + return InferLayoutOutput({input_layout}, {input_layout}, Attrs()); + } + if (!attrs->axis.defined()) { + return InferLayoutOutput({input_layout}, {LayoutDecision("")}, Attrs()); + } + Array empty; + const auto& input_shape = + Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + if (input_shape.size() == 0) { + return InferLayoutOutput(); + } + std::vector axes; + for (const auto& s : attrs->axis.value()) { + axes.push_back(CommonUtils::GetIndex(s->value, input_shape.size())); + } + LayoutDecision output_layout = LayoutUtils::ReduceLayout(input_layout, axes); + return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); +} + +InferLayoutOutput ForwardInferLayoutReshape(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision input_layout = InferLayoutDecision(call->args[0], var_layout_map); + if (!input_layout->layout.defined()) { + return InferLayoutOutput(); + } + Array empty; + const auto& input_shape = + Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& output_shape = + Downcast(GetStructInfo(call))->GetShape().value_or(empty); + if (input_shape.size() == 0 || output_shape.size() == 0) { + return InferLayoutOutput(); + } + LayoutDecision output_layout; + if (input_shape.size() == output_shape.size()) { + output_layout = input_layout; + } else if (input_shape.size() > output_shape.size()) { + const auto& reduce_axes = InferReduceAxes(input_shape, output_shape); + if (reduce_axes.size() == 0) { + return InferLayoutOutput(); + } + output_layout = LayoutUtils::ReduceLayout(input_layout, reduce_axes); + } else { + const auto& expand_axes = InferExpandAxes(input_shape, output_shape); + if (expand_axes.size() == 0) { + return InferLayoutOutput(); + } + output_layout = LayoutUtils::ExpandLayout(input_layout, expand_axes); + } + return InferLayoutOutput({input_layout, LayoutDecision("O")}, {output_layout}, Attrs()); +} + +InferLayoutOutput ForwardInferLayoutSqueeze(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision input_layout = 
InferLayoutDecision(call->args[0], var_layout_map); + if (!input_layout->layout.defined()) { + return InferLayoutOutput(); + } + Array empty; + const auto& input_shape = + Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + if (input_shape.size() == 0) { + return InferLayoutOutput(); + } + const auto* attrs = call->attrs.as(); + std::vector reduce_axes; + if (attrs->axis.defined()) { + for (const auto& s : attrs->axis.value()) { + size_t v_index = CommonUtils::GetIndex(s->value, input_shape.size()); + if (Downcast(input_shape[v_index])->value == 1) { + reduce_axes.push_back(v_index); + } + } + } else { + for (size_t i = 0; i < input_shape.size(); i++) { + if (Downcast(input_shape[i])->value == 1) { + reduce_axes.push_back(i); + } + } + } + LayoutDecision output_layout = LayoutUtils::ReduceLayout(input_layout, reduce_axes); + return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); +} + +InferLayoutOutput ForwardInferLayoutTake(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision input_layout = InferLayoutDecision(call->args[1], var_layout_map); + if (!input_layout->layout.defined()) { + return InferLayoutOutput(); + } + LayoutDecision output_layout = LayoutUtils::ExpandLayout(input_layout, std::vector{0}); + return InferLayoutOutput({LayoutDecision("WE"), input_layout}, {output_layout}, Attrs()); +} + +TVM_REGISTER_OP("relax.nn.conv1d") + .set_attr("FMSCForwardInferLayout", MSCInferLayoutConv); +TVM_REGISTER_OP("relax.nn.conv2d") + .set_attr("FMSCForwardInferLayout", MSCInferLayoutConv); +TVM_REGISTER_OP("relax.nn.max_pool2d") + .set_attr("FMSCForwardInferLayout", MSCInferLayoutPool2d); +TVM_REGISTER_OP("relax.nn.avg_pool2d") + .set_attr("FMSCForwardInferLayout", MSCInferLayoutPool2d); +TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") + .set_attr("FMSCForwardInferLayout", MSCInferLayoutPool2d); +// reduce axis ops +TVM_REGISTER_OP("relax.argmax") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); +TVM_REGISTER_OP("relax.argmin") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); +TVM_REGISTER_OP("relax.max") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); +TVM_REGISTER_OP("relax.min") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); +TVM_REGISTER_OP("relax.mean") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); +TVM_REGISTER_OP("relax.sum") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); +TVM_REGISTER_OP("relax.prod") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); +TVM_REGISTER_OP("relax.std") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); +// binary ops +TVM_REGISTER_OP("relax.add") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.divide") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.floor_divide") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.multiply") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.power") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.subtract") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.equal") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.greater") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); 
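+// Note: a minimal sketch of how such per-op registrations are typically consumed, assuming the
+// attribute value type is the relax FRelaxInferLayout functor; the attribute map can then be
+// queried as:
+//   static const auto fmap = Op::GetAttrMap<FRelaxInferLayout>("FMSCForwardInferLayout");
+//   const Op& op = Downcast<Op>(call->op);
+//   if (fmap.count(op)) {
+//     InferLayoutOutput out = fmap[op](call, desired_layouts, var_layout_map);
+//   }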
+TVM_REGISTER_OP("relax.greater_equal") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.less") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.less_equal") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.not_equal") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.maximum") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.minimum") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.logical_and") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.logical_or") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.logical_xor") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.bitwise_and") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.bitwise_or") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +TVM_REGISTER_OP("relax.bitwise_xor") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); +// math ops +TVM_REGISTER_OP("relax.expand_dims") + .set_attr("FMSCForwardInferLayout", ForkwardInferLayoutExpandDims); +TVM_REGISTER_OP("relax.matmul") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutMatmul); +TVM_REGISTER_OP("relax.permute_dims") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutPermute); +TVM_REGISTER_OP("relax.reshape") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReshape); +TVM_REGISTER_OP("relax.squeeze") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutSqueeze); +TVM_REGISTER_OP("relax.take") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutTake); +// nn ops +TVM_REGISTER_OP("relax.nn.batch_norm") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBatchNorm); +TVM_REGISTER_OP("relax.nn.group_norm") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutNormalize); +TVM_REGISTER_OP("relax.nn.layer_norm") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutNormalize); + +// Backward Infer +InferLayoutOutput BackwardInferLayoutCommon(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + NLayout output_layout = InferNLayout(call, var_layout_map); + LayoutDecision layout_hint; + if (output_layout.IsLeaf()) { + layout_hint = output_layout.LeafValue(); + } else { + for (const auto& l : output_layout.NestedArray()) { + if (l.IsLeaf() && l.LeafValue()->layout.defined()) { + layout_hint = l.LeafValue(); + } + } + } + if (!layout_hint->layout.defined()) { + return InferLayoutOutput(); + } + Array input_layouts; + for (const auto& arg : call->args) { + const auto& saved_layout = InferLayoutDecision(arg, var_layout_map); + if (saved_layout->layout.defined()) { + input_layouts.push_back(saved_layout); + } else { + input_layouts.push_back(layout_hint); + } + } + return InferLayoutOutput(input_layouts, {output_layout}, Attrs()); +} + +InferLayoutOutput BackwardInferLayoutBinary(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + const auto& output = BackwardInferLayoutCommon(call, desired_layouts, var_layout_map); + if (!output.defined()) { + return output; + } + std::vector input_layouts; + for (size_t i = 0; i < call->args.size(); i++) { + const auto& sinfo = GetStructInfo(call->args[i]); + if (const auto* t_info = sinfo.as()) { + if (t_info->ndim == 0) { + 
input_layouts.push_back(LayoutDecision("")); + } else { + input_layouts.push_back(output->input_layouts[i]); + } + } else { + LOG(FATAL) << "Binary input should be tensor, get " << sinfo->GetTypeKey(); + } + } + return InferLayoutOutput(input_layouts, output->output_layouts, Attrs()); +} + +InferLayoutOutput BackwardInferLayoutInplace(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + return BackwardInferLayoutCommon(call, desired_layouts, var_layout_map); +} + +InferLayoutOutput BackwardInferLayoutBatchNorm(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision output_layout = InferLayoutDecisionAt(call, var_layout_map, 0); + if (!output_layout->layout.defined()) { + return InferLayoutOutput(); + } + LayoutDecision g_layout = LayoutDecision("O"); + return InferLayoutOutput({output_layout, g_layout, g_layout, g_layout, g_layout}, + {{output_layout, g_layout, g_layout}}, Attrs()); +} + +InferLayoutOutput BackwardInferLayoutExpandDims(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map); + if (!output_layout->layout.defined()) { + return InferLayoutOutput(); + } + Array empty; + const auto& input_shape = + Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + if (input_shape.size() == 0) { + return InferLayoutOutput(); + } + const auto* attrs = call->attrs.as(); + std::vector expand_axes; + for (const auto& s : attrs->axis) { + expand_axes.push_back(CommonUtils::GetIndex(s->value, input_shape.size())); + } + LayoutDecision input_layout = LayoutUtils::ReduceLayout(output_layout, expand_axes); + return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); +} + +InferLayoutOutput BackwardInferLayoutNormalize(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision output_layout = InferLayoutDecisionAt(call, var_layout_map, 0); + if (!output_layout->layout.defined()) { + return InferLayoutOutput(); + } + LayoutDecision g_layout = LayoutDecision("O"); + return InferLayoutOutput({output_layout, g_layout, g_layout}, {output_layout}, Attrs()); +} + +InferLayoutOutput BackwardInferLayoutMatmul(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map); + if (!output_layout->layout.defined()) { + return InferLayoutOutput(); + } + Array empty; + const auto& b_shape = + Downcast(GetStructInfo(call->args[1]))->GetShape().value_or(empty); + if (b_shape.size() == 0) { + return InferLayoutOutput(); + } + size_t start = output_layout->layout.ndim() - b_shape.size(); + String pre_layout; + for (size_t i = start; i < output_layout->layout.ndim() - 2; i++) { + pre_layout = pre_layout + output_layout->layout[i].name(); + } + LayoutDecision b_layout = LayoutDecision(pre_layout + "IO"); + return InferLayoutOutput({output_layout, b_layout}, {output_layout}, Attrs()); +} + +InferLayoutOutput BackwardInferLayoutPermute(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map); + if (!output_layout->layout.defined()) { + return InferLayoutOutput(); + } + std::vector permute_axes; + const auto* attrs = call->attrs.as(); + if (!attrs->axes.defined()) { + for (size_t i = output_layout->layout.ndim(); i > 0; i--) { + permute_axes.push_back(i - 
1); + } + } else { + std::vector attr_axes; + for (const auto& s : attrs->axes.value()) { + attr_axes.push_back(s->value); + } + for (size_t i = 0; i < output_layout->layout.ndim(); i++) { + int pos = ArrayUtils::IndexOf(attr_axes, static_cast(i)); + if (pos >= 0) { + permute_axes.push_back(pos); + } else { + permute_axes.push_back(i); + } + } + } + LayoutDecision input_layout = LayoutUtils::PermuteLayout(output_layout, permute_axes); + return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); +} + +InferLayoutOutput BackwardInferLayoutReduceAxis(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map); + if (!output_layout->layout.defined()) { + return InferLayoutOutput(); + } + const auto* attrs = call->attrs.as(); + if (attrs->keepdims) { + return InferLayoutOutput({output_layout}, {output_layout}, Attrs()); + } + Array empty; + const auto& input_shape = + Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + if (input_shape.size() == 0) { + return InferLayoutOutput(); + } + std::vector axes; + for (const auto& s : attrs->axis.value()) { + axes.push_back(CommonUtils::GetIndex(s->value, input_shape.size())); + } + LayoutDecision input_layout = LayoutUtils::ExpandLayout(output_layout, axes); + return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); +} + +InferLayoutOutput BackwardInferLayoutReshape(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map); + if (!output_layout->layout.defined()) { + return InferLayoutOutput(); + } + Array empty; + const auto& input_shape = + Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& output_shape = + Downcast(GetStructInfo(call))->GetShape().value_or(empty); + if (input_shape.size() == 0 || output_shape.size() == 0) { + return InferLayoutOutput(); + } + LayoutDecision input_layout; + if (input_shape.size() == output_shape.size()) { + input_layout = output_layout; + } else if (input_shape.size() > output_shape.size()) { + const auto& reduce_axes = InferReduceAxes(input_shape, output_shape); + if (reduce_axes.size() == 0) { + return InferLayoutOutput(); + } + input_layout = LayoutUtils::ExpandLayout(output_layout, reduce_axes); + } else { + const auto& expand_axes = InferExpandAxes(input_shape, output_shape); + if (expand_axes.size() == 0) { + return InferLayoutOutput(); + } + input_layout = LayoutUtils::ReduceLayout(output_layout, expand_axes); + } + return InferLayoutOutput({input_layout, LayoutDecision("O")}, {output_layout}, Attrs()); +} + +InferLayoutOutput BackwardInferLayoutSqueeze(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map); + if (!output_layout->layout.defined()) { + return InferLayoutOutput(); + } + Array empty; + const auto& input_shape = + Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + if (input_shape.size() == 0) { + return InferLayoutOutput(); + } + const auto* attrs = call->attrs.as(); + std::vector reduce_axes; + if (attrs->axis.defined()) { + for (const auto& s : attrs->axis.value()) { + size_t v_index = CommonUtils::GetIndex(s->value, input_shape.size()); + if (Downcast(input_shape[v_index])->value == 1) { + reduce_axes.push_back(v_index); + } + } + } else { + for (size_t i = 0; i < input_shape.size(); i++) { + if 
(Downcast(input_shape[i])->value == 1) { + reduce_axes.push_back(i); + } + } + } + LayoutDecision input_layout = LayoutUtils::ExpandLayout(output_layout, reduce_axes); + return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); +} + +InferLayoutOutput BackwardInferLayoutTake(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map); + if (!output_layout->layout.defined()) { + return InferLayoutOutput(); + } + LayoutDecision input_layout = LayoutUtils::ReduceLayout(output_layout, std::vector{0}); + return InferLayoutOutput({LayoutDecision("WE"), input_layout}, {output_layout}, Attrs()); +} + +TVM_REGISTER_OP("relax.nn.conv1d") + .set_attr("FMSCBackwardInferLayout", MSCInferLayoutConv); +TVM_REGISTER_OP("relax.nn.conv2d") + .set_attr("FMSCBackwardInferLayout", MSCInferLayoutConv); +TVM_REGISTER_OP("relax.nn.max_pool2d") + .set_attr("FMSCBackwardInferLayout", MSCInferLayoutPool2d); +TVM_REGISTER_OP("relax.nn.avg_pool2d") + .set_attr("FMSCBackwardInferLayout", MSCInferLayoutPool2d); +TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") + .set_attr("FMSCBackwardInferLayout", MSCInferLayoutPool2d); +// reduce axis ops +TVM_REGISTER_OP("relax.argmax") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis); +TVM_REGISTER_OP("relax.argmin") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis); +TVM_REGISTER_OP("relax.max") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis); +TVM_REGISTER_OP("relax.min") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis); +TVM_REGISTER_OP("relax.mean") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis); +TVM_REGISTER_OP("relax.sum") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis); +TVM_REGISTER_OP("relax.prod") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis); +TVM_REGISTER_OP("relax.std") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis); +// binary ops +TVM_REGISTER_OP("relax.add") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.divide") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.floor_divide") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.multiply") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.power") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.subtract") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.equal") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.greater") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.greater_equal") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.less") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.less_equal") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.not_equal") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.maximum") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.minimum") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.logical_and") + 
.set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.logical_or") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.logical_xor") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.bitwise_and") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.bitwise_or") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +TVM_REGISTER_OP("relax.bitwise_xor") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); +// math ops +TVM_REGISTER_OP("relax.expand_dims") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutExpandDims); +TVM_REGISTER_OP("relax.matmul") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutMatmul); +TVM_REGISTER_OP("relax.permute_dims") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutPermute); +TVM_REGISTER_OP("relax.reshape") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReshape); +TVM_REGISTER_OP("relax.squeeze") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutSqueeze); +TVM_REGISTER_OP("relax.take") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutTake); +// nn ops +TVM_REGISTER_OP("relax.nn.batch_norm") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBatchNorm); +TVM_REGISTER_OP("relax.nn.group_norm") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutNormalize); +TVM_REGISTER_OP("relax.nn.layer_norm") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutNormalize); + +class LayoutInfer : public ExprVisitor { + public: + explicit LayoutInfer(const IRModule& ref_module) : ref_module_(ref_module) { Reset(); } + + void Reset() { + infered_ = false; + var_map_.clear(); + ordered_exprs_.clear(); + } + + void RecordExpr(const Var& var, const Expr& expr) { + var_map_.Set(var, expr); + ordered_exprs_.push_back(expr); + } + + Expr Infer(const Expr& expr) { + Reset(); + ForwardInfer(expr); + BackwardInfer(); + return expr; + } + + void ForwardInfer(const Expr& expr) { ExprVisitor::VisitExpr(expr); } + + void BackwardInfer() { + for (size_t e_idx = ordered_exprs_.size(); e_idx > 0; e_idx--) { + const Expr& expr = ordered_exprs_[e_idx - 1]; + if (const auto* t_node = expr.as()) { + continue; + } + if (const auto* t_node = expr.as()) { + continue; + } + if (!expr.as()) { + continue; + } + const Call& call = Downcast(expr); + size_t infered_num = 0; + for (const auto& arg : call->args) { + if (arg.as() && var_map_.count(Downcast(arg))) { + if (LayoutUtils::LayoutInfered(var_map_[Downcast(arg)]) > 0) { + infered_num++; + } + } else if (LayoutUtils::LayoutInfered(arg)) { + infered_num++; + } + } + if (call->args.size() == 0 || infered_num == call->args.size() || !call->op.as() || + LayoutUtils::HasUnknownDimTensor(call->args)) { + continue; + } + const OpNode* op_node = call->op.as(); + if (op_node == nullptr) { + continue; + } + // Infer by op_node + Op op = Downcast(GetRef(op_node)); + InferLayoutOutput infered_layout; + const auto msc_infer_map = Op::GetAttrMap("FMSCBackwardInferLayout"); + try { + if (msc_infer_map.count(op)) { + FRelaxInferLayout f = msc_infer_map[op]; + infered_layout = f(call, Map>(), var_layout_map_); + } else { + infered_layout = + BackwardInferLayoutCommon(call, Map>(), var_layout_map_); + } + } catch (runtime::InternalError& err) { + LOG(WARNING) << "Failed to backward infer layout " << expr << " : " << err.message(); + infered_layout = InferLayoutOutput(); + } + try { + if (infered_layout.defined()) { + 
SetInputLayouts(infered_layout->input_layouts, call); + } + } catch (runtime::InternalError& err) { + LOG(WARNING) << "Failed to backward set inputs layout for " << call << " : " + << err.message(); + } + } + } + + void SetInputLayouts(const Array& input_layouts, const Call& call) { + if (input_layouts.size() == call->args.size()) { + for (size_t i = 0; i < input_layouts.size(); i++) { + if (call->args[i].as()) { + const auto& var = Downcast(call->args[i]); + var_layout_map_[var] = input_layouts[i]; + if (var_map_.count(var)) { + if (LayoutUtils::SetLayout(var_map_[var], input_layouts[i])) { + infered_ = true; + } + } else if (LayoutUtils::SetLayout(var, input_layouts[i])) { + infered_ = true; + } + } else if (LayoutUtils::SetLayout(call->args[i], input_layouts[i])) { + infered_ = true; + } + } + } + } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { + ExprVisitor::VisitBinding_(binding, call_node); + const auto& call = GetRef(call_node); + if (const auto* v_node = call->op.as()) { + // infer global func and set var layouts + const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); + Infer(func); + for (size_t i = 0; i < func->params.size(); i++) { + if (var_layout_map_.count(func->params[i]) && + LayoutUtils::SetLayout(call->args[i], var_layout_map_[func->params[i]])) { + infered_ = true; + } + } + if (const auto* b_node = func->body.as()) { + var_layout_map_[binding->var] = GetNLayout(var_layout_map_, b_node->body); + if (LayoutUtils::SetLayout(call, var_layout_map_[binding->var])) { + infered_ = true; + } + } else { + LOG(FATAL) << "Function body should be SeqExpr, get " << func->body; + } + } else { + // infer call + bool infer_outputs = true; + RecordExpr(binding->var, call); + if (LayoutUtils::LayoutInfered(call)) { + infer_outputs = false; + } + if (call->args.size() == 0 || !call->op.as() || + LayoutUtils::HasUnknownDimTensor(call->args)) { + infer_outputs = false; + } + const OpNode* op_node = call->op.as(); + if (op_node == nullptr) { + infer_outputs = false; + } + if (infer_outputs) { + // infer layouts + Op op = Downcast(GetRef(op_node)); + InferLayoutOutput infered_layout; + const auto msc_infer_map = Op::GetAttrMap("FMSCForwardInferLayout"); + const auto relax_infer_map = Op::GetAttrMap("FRelaxInferLayout"); + bool set_inputs = true; + try { + if (msc_infer_map.count(op)) { + FRelaxInferLayout f = msc_infer_map[op]; + infered_layout = f(call, Map>(), var_layout_map_); + } else if (!relax_infer_map.count(op)) { + infered_layout = + ForwardInferLayoutCommon(call, Map>(), var_layout_map_); + } + if (relax_infer_map.count(op) && !infered_layout.defined()) { + FRelaxInferLayout f = relax_infer_map[op]; + infered_layout = f(call, Map>(), var_layout_map_); + set_inputs = false; + } + } catch (runtime::InternalError& err) { + LOG(WARNING) << "Failed to forward infer layout for " << binding->var << " : " + << binding->value << ", reason: " << err.message(); + infered_layout = InferLayoutOutput(); + } + if (infered_layout.defined() && infered_layout->output_layouts.size() == 1) { + try { + var_layout_map_[binding->var] = infered_layout->output_layouts[0]; + if (LayoutUtils::SetLayout(call, var_layout_map_[binding->var])) { + infered_ = true; + } + } catch (runtime::InternalError& err) { + LOG(WARNING) << "Failed to forward set output layout for " << binding->var << " : " + << binding->value << ", reason: " << err.message(); + } + } + if (set_inputs && infered_layout.defined()) { + try { + SetInputLayouts(infered_layout->input_layouts, 
call); + } catch (runtime::InternalError& err) { + LOG(WARNING) << "Failed to forward set inputs layout for " << call << " : " + << err.message(); + } + } + } + } + } + + void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) final { + ExprVisitor::VisitBinding_(binding, val); + std::vector input_layout; + for (const auto& field : val->fields) { + if (binding->var->IsInstance()) { + // Df var: Use the current realized layout to group the tuple; + input_layout.push_back(GetNLayout(var_layout_map_, field)); + } else { + // Global var: Use the initial layout to group the tuple; + input_layout.push_back(InitialNLayout(field)); + } + } + if (IsNestedTensor(binding->var)) { + var_layout_map_[binding->var] = input_layout; + } + RecordExpr(binding->var, GetRef(val)); + } + + void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) final { + ExprVisitor::VisitBinding_(binding, val); + NLayout input_layout = binding->var->IsInstance() + ? GetNLayout(var_layout_map_, val->tuple) + : InitialNLayout(val->tuple); + var_layout_map_[binding->var] = input_layout.NestedArray()[val->index]; + RecordExpr(binding->var, GetRef(val)); + } + + void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) final { + ExprVisitor::VisitBinding_(binding, val); + const NLayout& out_layout = LayoutDecision("O"); + var_layout_map_[binding->var] = out_layout; + if (LayoutUtils::SetLayout(GetRef(val), out_layout)) { + infered_ = true; + } + } + + bool infered() { return infered_; } + + private: + IRModule ref_module_; + bool infered_; + Map var_map_; + Array ordered_exprs_; + std::unordered_map var_layout_map_; +}; // class LayoutInfer + +class LayoutChecker : public ExprVisitor { + public: + LayoutChecker() { missing_num_ = 0; } + + void Check(const Expr& expr) { + ExprVisitor::VisitExpr(expr); + ICHECK_EQ(missing_num_, 0) << "Some layout is missing"; + } + + void VisitExpr_(const CallNode* call) final { + ExprVisitor::VisitExpr_(call); + if (!LayoutUtils::LayoutInfered(GetRef(call))) { + missing_num_++; + } + } + + void VisitExpr_(const ConstantNode* cn) final { + ExprVisitor::VisitExpr_(cn); + if (!LayoutUtils::LayoutInfered(GetRef(cn))) { + missing_num_++; + } + } + + private: + size_t missing_num_; +}; // class LayoutChecker + +void SetExprLayout(const IRModule& ref_module, const Expr& func, bool allow_missing) { + auto layout_infer = LayoutInfer(ref_module); + auto new_func = layout_infer.Infer(func); + if (!allow_missing) { + LayoutChecker().Check(new_func); + } +} + +namespace transform { + +Pass SetExprLayout(bool allow_missing, const String& entry_name) { + runtime::TypedPackedFunc pass_func = [=](IRModule m, + PassContext pc) { + relax::SetExprLayout(m, m->Lookup(entry_name), allow_missing); + return m; + }; + return CreateModulePass(pass_func, 0, "SetExprLayout", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.SetExprLayout").set_body_typed(SetExprLayout); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/contrib/msc/core/transform/set_expr_name.cc b/src/contrib/msc/core/transform/set_expr_name.cc new file mode 100644 index 000000000000..5b39a5a7ac1b --- /dev/null +++ b/src/contrib/msc/core/transform/set_expr_name.cc @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/core/transform/set_expr_name.cc + * \brief Pass for setting name for call and constant. + */ + +#include +#include +#include +#include +#include +#include + +#include "../utils.h" + +namespace tvm { +using namespace tvm::contrib::msc; + +namespace relax { + +/*! + * \brief Name setter for Relax + */ +class RelaxExprNameSetter : public ExprVisitor { + public: + explicit RelaxExprNameSetter(const IRModule& ref_module) : ref_module_(ref_module) {} + + void VisitBindingBlock(const BindingBlock& block) final { + String block_name = SpanUtils::GetAttr(block->span, "name"); + if (block_name.size() == 0) { + block_name = "block"; + } + if (setted_blocks_.count(block_name)) { + int cnt = 1; + while (setted_blocks_.count(block_name + "_" + std::to_string(cnt))) { + cnt++; + } + block_name = block_name + "_" + std::to_string(cnt); + } + setted_blocks_.insert(block_name); + block_stack_.push_back(block_name); + const String& unique_name = StringUtils::Join(block_stack_, "."); + block->span = SpanUtils::SetAttr(block->span, "name", unique_name); + ExprVisitor::VisitBindingBlock(block); + block_stack_.pop_back(); + } + + void VisitExpr_(const ConstantNode* val) { + ExprVisitor::VisitExpr_(val); + const String& unique_name = GetUniqueName(GetRef(val), "const"); + if (unique_name != SpanUtils::GetAttr(val->span, "name")) { + val->span = SpanUtils::SetAttr(val->span, "name", unique_name); + } + expr_names_.Set(GetRef(val), unique_name); + } + + void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) { + ExprVisitor::VisitBinding_(binding, val); + const String& unique_name = GetUniqueName(GetRef(val), "const"); + if (unique_name != SpanUtils::GetAttr(val->span, "name")) { + val->span = SpanUtils::SetAttr(val->span, "name", unique_name); + } + expr_names_.Set(binding->var, unique_name); + } + + void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) { + ExprVisitor::VisitBinding_(binding, val); + const String& unique_name = GetUniqueName(GetRef(val), "shape"); + if (unique_name != SpanUtils::GetAttr(val->span, "name")) { + val->span = SpanUtils::SetAttr(val->span, "name", unique_name); + } + expr_names_.Set(binding->var, unique_name); + } + + void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) { + ExprVisitor::VisitBinding_(binding, val); + const String& unique_name = GetUniqueName(GetRef(val), "tuple"); + if (unique_name != SpanUtils::GetAttr(val->span, "name")) { + val->span = SpanUtils::SetAttr(val->span, "name", unique_name); + } + expr_names_.Set(binding->var, unique_name); + } + + void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) { + ExprVisitor::VisitBinding_(binding, val); + ICHECK(expr_names_.count(val->tuple)) << "Can not find tuple of " << GetRef(val); + const String& unique_name = expr_names_[val->tuple] + "." 
+ std::to_string(val->index); + if (unique_name != SpanUtils::GetAttr(val->span, "name")) { + val->span = SpanUtils::SetAttr(val->span, "name", unique_name); + } + expr_names_.Set(binding->var, unique_name); + } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* val) { + ExprVisitor::VisitBinding_(binding, val); + String name_hint, optype; + if (const auto* op_node = val->op.as()) { + const std::string& op_name = op_node->name; + int rpos = op_name.rfind("."); + name_hint = op_name.substr(rpos + 1); + optype = StringUtils::Replace(op_node->name, "relax.", ""); + } else if (const auto* v_node = val->op.as()) { + const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); + ExprVisitor::VisitExpr(func); + const auto& name_opt = func->GetAttr(attr::kComposite); + ICHECK(name_opt.defined()) << "Unexpected global func without composite"; + name_hint = name_opt.value(); + optype = name_hint; + } + // set name + const String& unique_name = GetUniqueName(GetRef(val), name_hint); + if (unique_name != SpanUtils::GetAttr(val->span, "name")) { + val->span = SpanUtils::SetAttr(val->span, "name", unique_name); + } + // set constant consumer && master + Array input_types; + try { + input_types = ExprUtils::GetInputTypes(optype, val->args.size(), true); + } catch (runtime::InternalError& err) { + LOG(WARNING) << "Failed to GetInputTypes for " << GetRef(val) << " : " << err.message(); + throw err; + } + for (size_t i = 0; i < input_types.size(); i++) { + if (input_types[i] == "input") { + continue; + } + if (const auto* c_node = val->args[i].as()) { + const String& const_name = SpanUtils::GetAttr(c_node->span, "name"); + if (constant_consumers_.count(const_name)) { + val->span = SpanUtils::SetAttr(val->span, "master", constant_consumers_[const_name]); + } else { + constant_consumers_.Set(const_name, unique_name); + } + } + } + expr_names_.Set(binding->var, unique_name); + } + + private: + const String GetUniqueName(const Expr& expr, const String& name_hint) { + String expr_name = SpanUtils::GetAttr(expr->span, "name"); + if (expr_name.size() == 0) { + expr_name = name_hint; + } + if (!setted_names_.count(expr_name)) { + setted_names_.Set(expr_name, expr); + return expr_name; + } + if (setted_names_[expr_name] == expr) { + return expr_name; + } + int cnt = 1; + while (setted_names_.count(expr_name + "_" + std::to_string(cnt)) && + setted_names_[expr_name + "_" + std::to_string(cnt)] != expr) { + cnt++; + } + expr_name = expr_name + "_" + std::to_string(cnt); + if (!setted_names_.count(expr_name)) { + setted_names_.Set(expr_name, expr); + } + return expr_name; + } + + Map setted_names_; + Map constant_consumers_; + std::set setted_blocks_; + Array block_stack_; + Map expr_names_; + IRModule ref_module_; +}; // class ExprNameSetter + +void SetRelaxExprName(const IRModule& ref_module, const Expr& e) { + RelaxExprNameSetter(ref_module).VisitExpr(e); +} + +namespace transform { + +Pass SetRelaxExprName(const String& entry_name) { + runtime::TypedPackedFunc pass_func = [=](IRModule m, + PassContext pc) { + relax::SetRelaxExprName(m, m->Lookup(entry_name)); + return m; + }; + return CreateModulePass(pass_func, 0, "SetRelaxExprName", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.SetRelaxExprName").set_body_typed(SetRelaxExprName); + +} // namespace transform +} // namespace relax + +namespace relay { + +/*! 
+ * \brief Name setter for Relay + */ +class RelayExprNameSetter : public ExprVisitor { + public: + explicit RelayExprNameSetter(const IRModule& ref_module) : ref_module_(ref_module) {} + + void VisitExpr_(const ConstantNode* op) final { + ExprVisitor::VisitExpr_(op); + const String& unique_name = GetUniqueName(GetRef(op), "const"); + if (unique_name != SpanUtils::GetAttr(op->span, "name")) { + op->span = SpanUtils::SetAttr(op->span, "name", unique_name); + } + } + + void VisitExpr_(const TupleNode* op) final { + ExprVisitor::VisitExpr_(op); + const String& unique_name = GetUniqueName(GetRef(op), "tuple"); + if (unique_name != SpanUtils::GetAttr(op->span, "name")) { + op->span = SpanUtils::SetAttr(op->span, "name", unique_name); + } + } + + void VisitExpr_(const TupleGetItemNode* op) final { + ExprVisitor::VisitExpr_(op); + const String& tuple_name = SpanUtils::GetAttr(op->tuple->span, "name"); + const String& unique_name = tuple_name + "." + std::to_string(op->index); + if (unique_name != SpanUtils::GetAttr(op->span, "name")) { + op->span = SpanUtils::SetAttr(op->span, "name", unique_name); + } + } + + void VisitExpr_(const FunctionNode* op) final { + ExprVisitor::VisitExpr_(op); + const auto& name_opt = op->GetAttr(attr::kComposite); + const String& name_hint = name_opt.defined() ? name_opt.value() : "func"; + const String& unique_name = GetUniqueName(GetRef(op), name_hint); + if (unique_name != SpanUtils::GetAttr(op->span, "name")) { + op->span = SpanUtils::SetAttr(op->span, "name", unique_name); + } + } + + void VisitExpr_(const CallNode* op) final { + ExprVisitor::VisitExpr_(op); + String name_hint, optype; + if (const auto* op_node = op->op.as()) { + const std::string& op_name = op_node->name; + int rpos = op_name.rfind("."); + name_hint = op_name.substr(rpos + 1); + optype = StringUtils::Replace(op_node->name, "relay.", ""); + } else if (const auto* v_node = op->op.as()) { + const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); + ExprVisitor::VisitExpr(func); + const auto& name_opt = func->GetAttr(attr::kComposite); + ICHECK(name_opt.defined()) << "Unexpected global func without composite"; + optype = name_opt.value(); + name_hint = optype; + } + // set name + const String& unique_name = GetUniqueName(GetRef(op), name_hint); + if (unique_name != SpanUtils::GetAttr(op->span, "name")) { + op->span = SpanUtils::SetAttr(op->span, "name", unique_name); + } + // set constant consumer && master + Array input_types; + try { + input_types = ExprUtils::GetInputTypes(optype, op->args.size(), false); + } catch (runtime::InternalError& err) { + LOG(WARNING) << "Failed to GetInputTypes for " << GetRef(op) << " : " << err.message(); + throw err; + } + for (size_t i = 0; i < input_types.size(); i++) { + if (input_types[i] == "input") { + continue; + } + if (const auto* c_node = op->args[i].as()) { + const String& const_name = SpanUtils::GetAttr(c_node->span, "name"); + if (constant_consumers_.count(const_name)) { + op->span = SpanUtils::SetAttr(op->span, "master", constant_consumers_[const_name]); + } else { + constant_consumers_.Set(const_name, unique_name); + } + } + } + } + + private: + const String GetUniqueName(const Expr& expr, const String& name_hint) { + String expr_name = SpanUtils::GetAttr(expr->span, "name"); + if (expr_name.size() == 0) { + expr_name = name_hint; + } + if (!setted_names_.count(expr_name)) { + setted_names_.Set(expr_name, expr); + return expr_name; + } + if (setted_names_[expr_name] == expr) { + return expr_name; + } + int cnt = 1; + while 
(setted_names_.count(expr_name + "_" + std::to_string(cnt)) && + setted_names_[expr_name + "_" + std::to_string(cnt)] != expr) { + cnt++; + } + expr_name = expr_name + "_" + std::to_string(cnt); + if (!setted_names_.count(expr_name)) { + setted_names_.Set(expr_name, expr); + } + return expr_name; + } + + Map setted_names_; + Map constant_consumers_; + IRModule ref_module_; +}; // class ExprNameSetter + +void SetRelayExprName(const IRModule& ref_module, const Expr& e) { + RelayExprNameSetter(ref_module).VisitExpr(e); +} + +namespace transform { + +Pass SetRelayExprName(const String& entry_name) { + runtime::TypedPackedFunc pass_func = [=](IRModule m, + PassContext pc) { + relay::SetRelayExprName(m, m->Lookup(entry_name)); + return m; + }; + return CreateModulePass(pass_func, 0, "SetRelayExprName", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.SetRelayExprName").set_body_typed(SetRelayExprName); + +} // namespace transform +} // namespace relay + +} // namespace tvm diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc new file mode 100644 index 000000000000..66bbf8cc9790 --- /dev/null +++ b/src/contrib/msc/core/utils.cc @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/core/utils.cc + */ + +#include "utils.h" + +#include +namespace tvm { +namespace contrib { +namespace msc { + +size_t CommonUtils::GetIndex(int index, size_t max_size) { + size_t v_index; + if (index < 0) { + v_index = index + max_size; + } else { + v_index = index; + } + ICHECK_LT(v_index, max_size) << "Index " << index << " out of range " << max_size; + return v_index; +} + +const Array StringUtils::Split(const String& src_string, const String& sep) { + Array sub_strings; + if (src_string.size() == 0) { + return sub_strings; + } + std::string src_cstring = src_string; + const std::string& csep = sep; + int pos = src_cstring.find(csep); + while (pos >= 0) { + if (pos > 0) { + sub_strings.push_back(src_cstring.substr(0, pos)); + } + src_cstring = src_cstring.substr(pos + csep.size()); + pos = src_cstring.find(csep); + } + if (src_cstring.size() > 0) { + sub_strings.push_back(src_cstring); + } + return sub_strings; +} + +const String StringUtils::Join(const Array& sub_strings, const String& joint) { + String join_str = ""; + for (size_t i = 0; i < sub_strings.size(); i++) { + join_str = join_str + sub_strings[i] + (i == sub_strings.size() - 1 ? "" : joint); + } + return join_str; +} + +const String StringUtils::Replace(const String& src_string, const String& old_str, + const String& new_str) { + String new_string; + const auto& sub_strings = Split(src_string, old_str); + for (size_t i = 0; i < sub_strings.size(); i++) { + new_string = new_string + sub_strings[i] + (i == sub_strings.size() - 1 ? 
"" : new_str); + } + return new_string; +} + +const std::tuple StringUtils::SplitOnce(const String& src_string, const String& sep, + bool from_left) { + if (src_string.size() == 0) { + return std::make_tuple(String(), String()); + } + std::string src_cstring = src_string; + const std::string& csep = sep; + int pos = from_left ? src_cstring.find(csep) : src_cstring.rfind(csep); + if (pos >= 0) { + return std::make_tuple(src_cstring.substr(0, pos), src_cstring.substr(pos + csep.size())); + } + return std::make_tuple(src_string, String()); +} + +const Array StringUtils::GetClosures(const String& src_string, const String& left, + const String& right) { + Array tokens; + if (src_string.size() == 0) { + return tokens; + } + String token = "start"; + String left_str = src_string; + while (token.size() > 0) { + std::tie(token, left_str) = StringUtils::SplitOnce(left_str, left); + if (left_str.size() > 0) { + std::tie(token, left_str) = StringUtils::SplitOnce(left_str, right); + } else { + token = ""; + } + if (token.size() > 0) { + tokens.push_back(token); + } + } + return tokens; +} + +const String StringUtils::GetClosureOnce(const String& src_string, const String& left, + const String& right, bool from_left) { + if (src_string.size() == 0) { + return ""; + } + String val = std::get<1>(SplitOnce(src_string, left, from_left)); + if (val.size() > 0) { + val = std::get<0>(StringUtils::SplitOnce(val, right, from_left)); + } + return val; +} + +const String StringUtils::ToString(const runtime::ObjectRef& obj) { + String obj_string; + if (!obj.defined()) { + obj_string = ""; + } else if (obj.as()) { + obj_string = Downcast(obj); + } else if (const auto* n = obj.as()) { + obj_string = std::to_string(n->value); + } else if (const auto* n = obj.as()) { + obj_string = std::to_string(n->value); + } else if (const auto* n = obj.as()) { + for (size_t i = 0; i < n->size(); i++) { + obj_string = obj_string + ToString((*n)[i]); + if (n->size() == 1 || i < n->size() - 1) { + obj_string = obj_string + ","; + } + } + } else { + std::ostringstream obj_des; + obj_des << obj; + obj_string = obj_des.str(); + } + return obj_string; +} + +bool StringUtils::CompareArrays(const Array& left, const Array& right, int size) { + if (left.size() == right.size() == 0) { + return true; + } + if (size == -1 && left.size() != right.size()) { + return false; + } + if (left.size() == 0 || right.size() == 0) { + return false; + } + size = left.size(); + ICHECK_GT(size, 0) << "Positive size should be given, get " << size; + if (size > left.size() || size > right.size()) { + return false; + } + for (size_t i = 0; i < size; i++) { + if (left[i] != right[i]) { + return false; + } + } + return true; +} + +const Span SpanUtils::SetAttr(const Span& span, const String& key, const String& value) { + if (value.size() == 0) { + return span; + } + String new_source; + Array tokens{"<" + key + ">", ""}; + if (span.defined() && span->source_name.defined()) { + const String& source_str = span->source_name->name; + String left = std::get<0>(StringUtils::SplitOnce(source_str, tokens[0])); + String right = std::get<1>(StringUtils::SplitOnce(source_str, tokens[1])); + if (left.size() > 0) { + new_source = left + tokens[0] + value + tokens[1] + right; + } else { + new_source = source_str + tokens[0] + value + tokens[1]; + } + } else { + new_source = tokens[0] + value + tokens[1]; + } + if (span.defined()) { + return Span(SourceName::Get(new_source), span->line, span->end_line, span->column, + span->end_column); + } + return 
Span(SourceName::Get(new_source), 0, 0, 0, 0); +} + +const String SpanUtils::GetAttr(const Span& span, const String& key) { + if (span.defined() && span->source_name.defined()) { + Array tokens{"<" + key + ">", ""}; + return StringUtils::GetClosureOnce(span->source_name->name, tokens[0], tokens[1]); + } + return ""; +} + +const Map SpanUtils::GetAttrs(const Span& span) { + Map attrs; + for (const auto& key : StringUtils::GetClosures(span->source_name->name, "")) { + attrs.Set(key, GetAttr(span, key)); + } + return attrs; +} + +const Array ExprUtils::GetInputTypes(const String& optype, size_t inputs_num, + bool as_relax) { + Array input_types; + if (as_relax && (optype == "broadcast_to" || optype == "reshape")) { + input_types.push_back("input"); + input_types.push_back("shape"); + } else if (optype == "clip" && as_relax) { + input_types.push_back("input"); + input_types.push_back("min"); + input_types.push_back("max"); + } else if (optype == "full" && as_relax) { + input_types.push_back("shape"); + input_types.push_back("input"); + } else if (optype == "trilu") { + input_types.push_back("input"); + input_types.push_back("k"); + } else if (optype == "image.resize2d" && as_relax) { + input_types.push_back("input"); + input_types.push_back("size"); + } else if (optype == "nn.conv1d" || optype == "nn.conv2d" || optype == "nn.conv3d") { + input_types.push_back("input"); + input_types.push_back("weight"); + } else if (optype == "nn.batch_norm") { + input_types.push_back("input"); + input_types.push_back("gamma"); + input_types.push_back("beta"); + input_types.push_back("mean"); + input_types.push_back("var"); + } else if (optype == "nn.layer_norm" || optype == "nn.group_norm") { + input_types.push_back("input"); + input_types.push_back("gamma"); + input_types.push_back("beta"); + } else if (optype == "msc.linear") { + if (as_relax) { + input_types.push_back("weight"); + input_types.push_back("input"); + } else { + input_types.push_back("input"); + input_types.push_back("weight"); + } + } else if (optype == "msc.conv1d_bias" || optype == "msc.conv2d_bias") { + input_types.push_back("input"); + input_types.push_back("weight"); + input_types.push_back("bias"); + if (as_relax) { + input_types.push_back("expand_bias"); + } + } else if (optype == "msc.linear_bias") { + if (as_relax) { + input_types.push_back("weight"); + input_types.push_back("input"); + } else { + input_types.push_back("input"); + input_types.push_back("weight"); + } + input_types.push_back("bias"); + } else if (optype == "msc.embedding" && inputs_num == 2) { + input_types.push_back("input"); + input_types.push_back("weight"); + } else if (optype == "msc.embedding" && inputs_num == 4) { + input_types.push_back("input"); + input_types.push_back("reduce_in"); + input_types.push_back("weight"); + input_types.push_back("expand_out"); + } else if (optype == "msc.gelu") { + input_types.push_back("input"); + input_types.push_back("factor_1"); + input_types.push_back("factor_2"); + input_types.push_back("factor_3"); + } else { + for (size_t i = 0; i < inputs_num; i++) { + input_types.push_back("input"); + } + } + ICHECK_EQ(input_types.size(), inputs_num) + << "Optype " << optype << " get input types " << input_types << " and inputs_num " + << inputs_num << " mismatch"; + return input_types; +} + +const Array ExprUtils::GetInputTypes(const RelaxCall& call) { + const String& optype = StringUtils::Replace(Downcast(call->op)->name, "relax.", ""); + return GetInputTypes(optype, call->args.size(), true); +} + +const Array 
ExprUtils::GetInputTypes(const RelayCall& call) { + const String& optype = StringUtils::Replace(Downcast(call->op)->name, "relay.", ""); + return GetInputTypes(optype, call->args.size(), false); +} + +TVM_REGISTER_GLOBAL("msc.core.SpanGetAttr").set_body_typed(SpanUtils::GetAttr); + +TVM_REGISTER_GLOBAL("msc.core.SpanGetAttrs").set_body_typed(SpanUtils::GetAttrs); + +} // namespace msc +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h new file mode 100644 index 000000000000..9da4ce3346f9 --- /dev/null +++ b/src/contrib/msc/core/utils.h @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/core/utils.h + * \brief Common utilities for msc. + */ +#ifndef TVM_CONTRIB_MSC_CORE_UTILS_H_ +#define TVM_CONTRIB_MSC_CORE_UTILS_H_ + +#include +#include +#include + +#include +#include + +namespace tvm { +namespace contrib { +namespace msc { + +using Expr = tvm::RelayExpr; +using RelaxCall = tvm::relax::Call; +using RelayCall = tvm::relay::Call; + +class CommonUtils { + public: + /*! + * \brief Check if the index is in range. + * \return The valid index. + */ + TVM_DLL static size_t GetIndex(int index, size_t max_size); +}; + +/*! + * \brief Utils for String. + */ +class StringUtils { + public: + /*! + * \brief Split the String into sub Strings. + * \return The SubStrings. + */ + TVM_DLL static const Array Split(const String& src_string, const String& sep); + + /*! + * \brief Join the SubStrings into String. + * \return The String. + */ + TVM_DLL static const String Join(const Array& sub_strings, const String& joint); + + /*! + * \brief Replace the substring old to new in String. + * \return The replaced String. + */ + TVM_DLL static const String Replace(const String& src_string, const String& old_str, + const String& new_str); + + /*! + * \brief Split the String into two sub Strings, only split by the frist seq. + * \return The SubStrings. + */ + TVM_DLL static const std::tuple SplitOnce(const String& src_string, + const String& sep, + bool from_left = true); + + /*! + * \brief Get the tokens between left and right. + * \return The Tokens. + */ + TVM_DLL static const Array GetClosures(const String& src_string, const String& left, + const String& right); + + /*! + * \brief Get the first token between left and right. + * \return The Token. + */ + TVM_DLL static const String GetClosureOnce(const String& src_string, const String& left, + const String& right, bool from_left = true); + + /*! + * \brief Change Object to String. + * \return The String. + */ + TVM_DLL static const String ToString(const runtime::ObjectRef& obj); + + /*! + * \brief Compare String arrays. + * \return Whether two array are same. 
+ */ + TVM_DLL static bool CompareArrays(const Array& left, const Array& right, + int size = -1); +}; + +/*! + * \brief Utils for Array. + */ +class ArrayUtils { + public: + /*! + * \brief Replace the element old to new in Array. + * \return The replaced Array. + */ + template + TVM_DLL static const Array Replace(const Array& src_array, const T& old_ele, + const T& new_ele) { + Array new_array; + for (const auto& a : src_array) { + if (a == old_ele) { + new_array.push_back(new_ele); + } else { + new_array.push_back(a); + } + } + return new_array; + } + + /*! + * \brief Find the index of element. + * \return The index, -1 if not found. + */ + template + TVM_DLL static int IndexOf(const std::vector& array, const T& ele) { + for (size_t i = 0; i < array.size(); i++) { + if (array[i] == ele) { + return i; + } + } + return -1; + } + + /*! + * \brief Downcast elements in the array. + * \return The downcasted array + */ + template + TVM_DLL static const Array Cast(const Array& src_array) { + Array new_array; + for (const auto& s : src_array) { + new_array.push_back(Downcast(s)); + } + return new_array; + } +}; + +/*! + * \brief Utils for Span. + */ +class SpanUtils { + public: + /*! + * \brief Set value to the Span. + * \return The new Span. + */ + TVM_DLL static const Span SetAttr(const Span& span, const String& key, const String& value); + + /*! + * \brief Get the value in value from the Span. + * \return The value String. + */ + TVM_DLL static const String GetAttr(const Span& span, const String& key); + + /*! + * \brief Get all the key:value in format value from the Span. + * \return The Attrs Map. + */ + TVM_DLL static const Map GetAttrs(const Span& span); +}; + +/*! + * \brief Utils for Expr. + */ +class ExprUtils { + public: + /*! + * \brief Get the input types of call. + * \return The input types. + */ + TVM_DLL static const Array GetInputTypes(const String& optype, size_t inputs_num, + bool as_relax); + + /*! + * \brief Get the input types of call. + * \return The input types. + */ + TVM_DLL static const Array GetInputTypes(const RelaxCall& call); + + /*! + * \brief Get the input types of call. + * \return The input types. + */ + TVM_DLL static const Array GetInputTypes(const RelayCall& call); + + /*! + * \brief Get the scalar value of ndarray. + * \return The scalar value. 
+ */ + template + TVM_DLL static const T GetScalar(const runtime::NDArray& array, size_t i = 0) { + if (array->dtype.code == kDLInt) { + if (array->dtype.bits == 8) { + return T(reinterpret_cast(array->data)[i]); + } else if (array->dtype.bits == 16) { + return T(reinterpret_cast(array->data)[i]); + } else if (array->dtype.bits == 32) { + return T(reinterpret_cast(array->data)[i]); + } else if (array->dtype.bits == 64) { + return T(reinterpret_cast(array->data)[i]); + } + } else if (array->dtype.code == kDLUInt) { + if (array->dtype.bits == 1) { // bool + return T(reinterpret_cast(array->data)[i]); + } else if (array->dtype.bits == 8) { + return T(reinterpret_cast(array->data)[i]); + } else if (array->dtype.bits == 16) { + return T(reinterpret_cast(array->data)[i]); + } else if (array->dtype.bits == 32) { + return T(reinterpret_cast(array->data)[i]); + } else if (array->dtype.bits == 64) { + return T(reinterpret_cast(array->data)[i]); + } + } else if (array->dtype.code == kDLFloat) { + if (array->dtype.bits == 32) { + return T(reinterpret_cast(array->data)[i]); + } else if (array->dtype.bits == 64) { + return T(reinterpret_cast(array->data)[i]); + } + } + LOG(FATAL) << "Failed to get scalar from array " << array; + } + + /*! + * \brief Get the scalar value of relax constant. + * \return The scalar value. + */ + template + TVM_DLL static const T GetScalar(const relax::Constant& constant, size_t i = 0) { + return GetScalar(constant->data, i); + } + + /*! + * \brief Get the scalar value of relay constant. + * \return The scalar value. + */ + template + TVM_DLL static const T GetScalar(const relay::Constant& constant, size_t i = 0) { + return GetScalar(constant->data, i); + } +}; + +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_MSC_CORE_UTILS_H_ diff --git a/tests/python/contrib/test_msc/test_transform_set_expr_name.py b/tests/python/contrib/test_msc/test_transform_set_expr_name.py new file mode 100644 index 000000000000..205c576b500c --- /dev/null +++ b/tests/python/contrib/test_msc/test_transform_set_expr_name.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
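+"""Test SetExprName transform in MSC.""" + +# The SetExprName pass writes a unique name into each expr span; the checkers +# below read it back with _ffi_api.SpanGetAttr(expr.span, "name"). Typical +# usage, mirroring the tests in this file: +#   mod = msc_transform.SetExprName()(mod)  # relax module +#   mod = msc_transform.SetExprName(as_relax=False)(mod)  # relay module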
+ +import tvm.testing +from tvm.relay import testing +from tvm.relay.expr_functor import ExprVisitor + +from tvm.relax.testing import nn +from tvm.relax import PyExprVisitor + +from tvm.contrib.msc.core import _ffi_api +from tvm.contrib.msc.core import transform as msc_transform + + +class RelayChecker(ExprVisitor): + """Check if the name span attribute is set.""" + + def check(self, expr): + self._missing_exprs = [] + self.visit(expr) + assert len(self._missing_exprs) == 0, "Missing {} names".format(len(self._missing_exprs)) + + def visit(self, expr): + super().visit(expr) + name = _ffi_api.SpanGetAttr(expr.span, "name") + if not name: + self._missing_exprs.append(expr) + + +class RelaxChecker(PyExprVisitor): + """Check if the name span attribute is set.""" + + def check(self, expr): + self._missing_exprs = [] + self.visit_expr(expr) + assert len(self._missing_exprs) == 0, "Missing {} names".format(len(self._missing_exprs)) + + def visit_binding(self, binding): + super().visit_binding(binding) + name = _ffi_api.SpanGetAttr(binding.value.span, "name") + if not name: + self._missing_exprs.append(binding.value) + + def visit_constant_(self, op): + super().visit_constant_(op) + name = _ffi_api.SpanGetAttr(op.span, "name") + if not name: + self._missing_exprs.append(op) + + +def test_relay(): + mod, _ = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float32") + mod = msc_transform.SetExprName(as_relax=False)(mod) + print("mod " + str(mod)) + RelayChecker().check(mod["main"]) + + +def test_relax(): + builder = tvm.relax.BlockBuilder() + + # a symbolic variable to represent minibatch size + n = tvm.tir.Var("n", "int64") + input_size = 784 + hidden_sizes = [128, 32] + output_size = 10 + + # build a three linear-layer neural network for a classification task + with builder.function("main"): + model = nn.Sequential( + nn.Linear(input_size, hidden_sizes[0]), + nn.ReLU(), + nn.Linear(hidden_sizes[0], hidden_sizes[1]), + nn.ReLU(), + nn.Linear(hidden_sizes[1], output_size), + nn.LogSoftmax(), + ) + data = nn.Placeholder((n, input_size), name="data") + output = model(data) + params = [data] + model.parameters() + builder.emit_func_output(output, params=params) + + # get and print the IRModule being built + mod = builder.get() + mod = msc_transform.SetExprName()(mod) + print("mod " + str(mod)) + RelaxChecker().check(mod["main"]) + + +if __name__ == "__main__": + # tvm.testing.main() + test_relay() From 75d44e8bc2cda4b1f7fe66e42c28b8f6eb920765 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Mon, 7 Aug 2023 10:43:27 +0800 Subject: [PATCH 2/8] roll back to M0.1 --- cmake/modules/LibInfo.cmake | 1 + src/support/libinfo.cc | 1 + .../test_transform_set_expr_layout.py | 73 ++++++++++++++++++ .../test_msc/test_transform_set_expr_name.py | 74 +++++++++---------- 4 files changed, 110 insertions(+), 39 deletions(-) create mode 100644 tests/python/contrib/test_msc/test_transform_set_expr_layout.py diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index 17a56ac439e2..9e1f71c72938 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -125,6 +125,7 @@ function(add_lib_info src_file) TVM_INFO_USE_TVM_CLML_VERSION="${CLML_VERSION_MAJOR}" TVM_INFO_USE_UMA="${USE_UMA}" TVM_INFO_USE_VERILATOR="${USE_VERILATOR}" + TVM_INFO_USE_MSC="${USE_MSC}" TVM_INFO_USE_CCACHE="${USE_CCACHE}" TVM_INFO_BACKTRACE_ON_SEGFAULT="${BACKTRACE_ON_SEGFAULT}" ) diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index 23091fffd23a..3f028ba65657 100644 --- 
a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -342,6 +342,7 @@ TVM_DLL Map GetLibInfo() { {"USE_CLML_GRAPH_EXECUTOR", TVM_INFO_USE_CLML_GRAPH_EXECUTOR}, {"USE_UMA", TVM_INFO_USE_UMA}, {"USE_VERILATOR", TVM_INFO_USE_VERILATOR}, + {"USE_MSC", TVM_INFO_USE_MSC}, {"USE_CCACHE", TVM_INFO_USE_CCACHE}, {"BACKTRACE_ON_SEGFAULT", TVM_INFO_BACKTRACE_ON_SEGFAULT}, }; diff --git a/tests/python/contrib/test_msc/test_transform_set_expr_layout.py b/tests/python/contrib/test_msc/test_transform_set_expr_layout.py new file mode 100644 index 000000000000..4717437d7662 --- /dev/null +++ b/tests/python/contrib/test_msc/test_transform_set_expr_layout.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm.testing +from tvm.relay import testing +from tvm.relay.expr_functor import ExprVisitor +from tvm.relay.build_module import bind_params_by_name + +from tvm.relax.frontend.torch import from_fx +from tvm.relax import PyExprVisitor + +from tvm.contrib.msc.core import _ffi_api +from tvm.contrib.msc.core import transform as msc_transform + + +class RelaxChecker(PyExprVisitor): + """Check if name as span attribute is setted.""" + + def check(self, expr): + self._missing_exprs = [] + if isinstance(expr, tvm.relax.Expr): + self.visit_expr(expr) + elif isinstance(expr, tvm.relax.BindingBlock): + self.visit_binding_block(expr) + assert len(self._missing_exprs) == 0, "Missing {} layouts".format(len(self._missing_exprs)) + + def visit_var_binding_(self, binding) -> None: + super().visit_var_binding_(binding) + layout = _ffi_api.SpanGetAttr(binding.value.span, "layout") + if not layout: + self._missing_exprs.append(binding.value) + + def visit_constant_(self, op) -> None: + super().visit_constant_(op) + layout = _ffi_api.SpanGetAttr(op.span, "layout") + if not layout: + self._missing_exprs.append(op) + + +def test_relax(): + try: + import torch + import torchvision + from torch import fx + except: + print("please install pytorch python package") + return + + torch_model = torchvision.models.resnet50() + graph_model = fx.symbolic_trace(torch_model) + input_info = [([1, 3, 224, 224], "float32")] + with torch.no_grad(): + mod = from_fx(graph_model, input_info) + mod = msc_transform.SetExprLayout()(mod) + RelaxChecker().check(mod) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/contrib/test_msc/test_transform_set_expr_name.py b/tests/python/contrib/test_msc/test_transform_set_expr_name.py index 205c576b500c..0c174ff7bd01 100644 --- a/tests/python/contrib/test_msc/test_transform_set_expr_name.py +++ b/tests/python/contrib/test_msc/test_transform_set_expr_name.py @@ -18,8 +18,9 @@ import tvm.testing from tvm.relay import testing from tvm.relay.expr_functor import ExprVisitor +from tvm.relay.build_module import bind_params_by_name 
-from tvm.relax.testing import nn +from tvm.relax.frontend.torch import from_fx from tvm.relax import PyExprVisitor from tvm.contrib.msc.core import _ffi_api @@ -31,11 +32,17 @@ class RelayChecker(ExprVisitor): def check(self, expr): self._missing_exprs = [] - super.visit(expr) + super().visit(expr) assert len(self._missing_exprs) == 0, "Missing {} names".format(len(self._missing_exprs)) - def visit(self, expr): - super().visit(expr) + def visit_constant(self, expr): + super().visit_constant(expr) + name = _ffi_api.SpanGetAttr(expr.span, "name") + if not name: + self._missing_exprs.append(expr) + + def visit_call(self, expr): + super().visit_call(expr) name = _ffi_api.SpanGetAttr(expr.span, "name") if not name: self._missing_exprs.append(expr) @@ -46,16 +53,19 @@ class RelaxChecker(PyExprVisitor): def check(self, expr): self._missing_exprs = [] - super.visit(expr) + if isinstance(expr, tvm.relax.Expr): + self.visit_expr(expr) + elif isinstance(expr, tvm.relax.BindingBlock): + self.visit_binding_block(expr) assert len(self._missing_exprs) == 0, "Missing {} names".format(len(self._missing_exprs)) - def visit_binding(self, binding): - super().visit_binding(binding) + def visit_var_binding_(self, binding) -> None: + super().visit_var_binding_(binding) name = _ffi_api.SpanGetAttr(binding.value.span, "name") if not name: self._missing_exprs.append(binding.value) - def visit_constant_(self, op): + def visit_constant_(self, op) -> None: super().visit_constant_(op) name = _ffi_api.SpanGetAttr(op.span, "name") if not name: @@ -63,43 +73,29 @@ def visit_constant_(self, op): def test_relay(): - mod, _ = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float32") + mod, params = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float32") + mod["main"] = bind_params_by_name(mod["main"], params) mod = msc_transform.SetExprName(as_relax=False)(mod) - print("mod " + str(mod)) RelayChecker().check(mod["main"]) def test_relax(): - builder = tvm.relax.BlockBuilder() - - # a symbolic variable to represent minibatch size - n = tvm.tir.Var("n", "int64") - input_size = 784 - hidden_sizes = [128, 32] - output_size = 10 - - # build a three linear-layer neural network for a classification task - with builder.function("main"): - model = nn.Sequential( - nn.Linear(input_size, hidden_sizes[0]), - nn.ReLU(), - nn.Linear(hidden_sizes[0], hidden_sizes[1]), - nn.ReLU(), - nn.Linear(hidden_sizes[1], output_size), - nn.LogSoftmax(), - ) - data = nn.Placeholder((n, input_size), name="data") - output = model(data) - params = [data] + model.parameters() - builder.emit_func_output(output, params=params) - - # get and print the IRmodule being built - mod = builder.get() + try: + import torch + import torchvision + from torch import fx + except: + print("please install pytorch python package") + return + + torch_model = torchvision.models.resnet50() + graph_model = fx.symbolic_trace(torch_model) + input_info = [([1, 3, 224, 224], "float32")] + with torch.no_grad(): + mod = from_fx(graph_model, input_info) mod = msc_transform.SetExprName()(mod) - print("mod " + str(mod)) - RelaxChecker().check(mod["main"]) + RelaxChecker().check(mod) if __name__ == "__main__": - # tvm.testing.main() - test_relay() + tvm.testing.main() From 0d3387950807eaea7f7cbe1f2a55e87277e9b85d Mon Sep 17 00:00:00 2001 From: Archermmt Date: Wed, 9 Aug 2023 06:18:11 +0800 Subject: [PATCH 3/8] add annotation --- .../tvm/contrib/msc/core/transform/pattern.py | 224 ++++++++++++++---- 1 file changed, 179 insertions(+), 45 deletions(-) diff 
--git a/python/tvm/contrib/msc/core/transform/pattern.py b/python/tvm/contrib/msc/core/transform/pattern.py index 500870509791..311889f2ec54 100644 --- a/python/tvm/contrib/msc/core/transform/pattern.py +++ b/python/tvm/contrib/msc/core/transform/pattern.py @@ -17,6 +17,8 @@ # pylint: disable=unused-argument """tvm.contrib.msc.core.transform.pattern""" +from typing import Mapping, Tuple + import tvm from tvm.relax.dpl import pattern as relax_pattern from tvm.relay import dataflow_pattern as relay_pattern @@ -26,8 +28,10 @@ from tvm.relay.op.contrib.register import register_pattern_table -def make_relax_conv_bias_pattern(op_name): - """A simple utility to create patterns for an operation fused with bias. +def make_relax_conv_bias_pattern( + op_name: str, +) -> Tuple[relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern]]: + """A simple utility to create patterns for an conv fused with bias. Parameters ---------- @@ -36,8 +40,13 @@ def make_relax_conv_bias_pattern(op_name): Returns ------- - pattern: DFPattern - The resulting pattern describing a conv_bias operation + out: tvm.relax.dpl.pattern.DFPattern + The resulting pattern describing a conv_bias operation. + + annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. """ data = relax_pattern.wildcard() @@ -52,20 +61,34 @@ def make_relax_conv_bias_pattern(op_name): def _check_relax_conv_bias(context: PatternCheckContext) -> bool: - """Check if conv_bias fuse pattern is correct.""" + """Check if conv_bias fuse pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. + """ + bias = context.annotated_expr["bias"] reshape = context.annotated_expr["reshape"] non_one_dims = len([i for i in reshape.struct_info.shape.values if i > 1]) return non_one_dims <= 1 and bias.struct_info.ndim == 1 -def make_relax_linear_pattern(): +def make_relax_linear_pattern() -> Tuple[ + relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern] +]: """A simple utility to create patterns for linear. Returns ------- - pattern: DFPattern - The resulting pattern describing a linear operation + out: tvm.relax.dpl.pattern.DFPattern + The resulting pattern describing a linear operation. + + annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. """ data = relax_pattern.wildcard() @@ -77,19 +100,34 @@ def make_relax_linear_pattern(): def _check_relax_linear(context: PatternCheckContext) -> bool: - """Check if linear pattern is correct.""" + """Check if linear pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. + """ + weight = context.annotated_expr["weight"] permute = context.annotated_expr["permute"] return weight.struct_info.ndim == 2 and not permute.attrs["axes"] -def make_relax_linear_bias_pattern(): +def make_relax_linear_bias_pattern() -> Tuple[ + relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern] +]: """A simple utility to create patterns for linear with bias. Returns ------- - pattern: DFPattern - The resulting pattern describing a linear_bias operation + out: tvm.relax.dpl.pattern.DFPattern + The resulting pattern describing a linear_bias operation. + + annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] + A mapping from name to sub pattern. 
It can be used to extract + important expressions from match result, to power the partition + check function and codegen. + """ linear, annotations = make_relax_linear_pattern() @@ -100,20 +138,34 @@ def make_relax_linear_bias_pattern(): def _check_relax_linear_bias(context: PatternCheckContext) -> bool: - """Check if linear_bias pattern is correct.""" + """Check if linear_bias pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. + """ + if not _check_relax_linear(context): return False bias = context.annotated_expr["bias"] return bias.struct_info.ndim == 1 -def make_relax_embedding_pattern(): +def make_relax_embedding_pattern() -> Tuple[ + relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern] +]: """A simple utility to create patterns for embedding. Returns ------- - pattern: DFPattern - The resulting pattern describing a embedding operation + out: tvm.relax.dpl.pattern.DFPattern + The resulting pattern describing a embedding operation. + + annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. """ weight = relax_pattern.is_const() @@ -125,7 +177,14 @@ def make_relax_embedding_pattern(): def _check_relax_embedding(context: PatternCheckContext) -> bool: - """Check if 1d embedding pattern is correct.""" + """Check if 1d embedding pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. + """ + weight = context.annotated_expr["weight"] astype = context.annotated_expr["astype"] return ( @@ -135,13 +194,20 @@ def _check_relax_embedding(context: PatternCheckContext) -> bool: ) -def make_relax_reshape_embedding_pattern(): +def make_relax_reshape_embedding_pattern() -> Tuple[ + relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern] +]: """A simple utility to create patterns for reshaped embedding. Returns ------- - pattern: DFPattern - The resulting pattern describing a reshaped rembedding operation + out: tvm.relax.dpl.pattern.DFPattern + The resulting pattern describing a reshape_embedding operation. + + annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. """ weight = relax_pattern.is_const() @@ -157,7 +223,14 @@ def make_relax_reshape_embedding_pattern(): def _check_relax_reshape_embedding(context: PatternCheckContext) -> bool: - """Check if reshape embedding pattern is correct.""" + """Check if reshape embedding pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. + """ + weight = context.annotated_expr["weight"] if weight.struct_info.ndim != 2 or weight.struct_info.dtype != "float32": return False @@ -168,13 +241,20 @@ def _check_relax_reshape_embedding(context: PatternCheckContext) -> bool: return True -def make_relax_attention_pattern(): +def make_relax_attention_pattern() -> Tuple[ + relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern] +]: """A simple utility to create patterns for attention. Returns ------- - pattern: DFPattern - The resulting pattern describing a attention operation + out: tvm.relax.dpl.pattern.DFPattern + The resulting pattern describing a attention operation. + + annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] + A mapping from name to sub pattern. 
It can be used to extract + important expressions from match result, to power the partition + check function and codegen. """ weight_q = relax_pattern.wildcard() @@ -189,17 +269,31 @@ def make_relax_attention_pattern(): def _check_relax_attention(context: PatternCheckContext) -> bool: - """Check if attention pattern is correct.""" + """Check if attention pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. + """ + return True -def make_relax_mask_attention_pattern(): +def make_relax_mask_attention_pattern() -> Tuple[ + relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern] +]: """A simple utility to create patterns for mask_attention. Returns ------- - pattern: DFPattern - The resulting pattern describing a mask_attention operation + out: tvm.relax.dpl.pattern.DFPattern + The resulting pattern describing a mask_attention operation. + + annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. """ weight_q = relax_pattern.wildcard() @@ -215,7 +309,14 @@ def make_relax_mask_attention_pattern(): def _check_relax_mask_attention(context: PatternCheckContext) -> bool: - """Check if mask_attention pattern is correct.""" + """Check if mask_attention pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. + """ + return True @@ -276,7 +377,7 @@ def pattern_table(): """Returns list of triples describing the name, dataflow pattern and predicate for all the MSC-supported operators.""" - def make_relay_conv_bias_pattern(op_name, optimized=False): + def make_relay_conv_bias_pattern(op_name, optimized=False) -> relay_pattern.DFPattern: """A simple utility to create patterns for an operation fused with bias. Parameters @@ -288,7 +389,7 @@ def make_relay_conv_bias_pattern(op_name, optimized=False): Returns ------- - pattern: DFPattern + pattern: tvm.relay.dataflow_pattern.DFPattern The resulting pattern describing a conv_bias operation """ @@ -303,7 +404,13 @@ def make_relay_conv_bias_pattern(op_name, optimized=False): return out def _check_relay_conv_bias(call: tvm.relay.Expr) -> bool: - """Check if conv_bias fuse pattern is correct.""" + """Check if conv_bias fuse pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. + """ if call.op.name == "nn.bias_add": bias = call.args[1] @@ -312,7 +419,7 @@ def _check_relay_conv_bias(call: tvm.relay.Expr) -> bool: return True return False - def make_relay_linear_pattern(optimized=False): + def make_relay_linear_pattern(optimized=False) -> relay_pattern.DFPattern: """A simple utility to create patterns for linear. Parameters @@ -322,7 +429,7 @@ def make_relay_linear_pattern(optimized=False): Returns ------- - pattern: DFPattern + pattern: tvm.relay.dataflow_pattern.DFPattern The resulting pattern describing a linear operation """ @@ -346,10 +453,17 @@ def make_relay_linear_pattern(optimized=False): return relay_pattern.is_op("squeeze")(reshape_out) def _check_relay_linear(call: tvm.relay.Expr) -> bool: - """Check if linear pattern is correct.""" + """Check if linear pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. + """ + return True - def make_relay_linear_bias_pattern(optimized=False): + def make_relay_linear_bias_pattern(optimized=False) -> relay_pattern.DFPattern: """A simple utility to create patterns for linear_bias. 
Parameters @@ -375,7 +489,7 @@ def _check_relay_linear_bias(call: tvm.relay.Expr) -> bool: """Check if linear_bias pattern is correct.""" return True - def make_relay_matmul_pattern(dim=2, optimized=False): + def make_relay_matmul_pattern(dim=2, optimized=False) -> relay_pattern.DFPattern: """A simple utility to create patterns for matmul. Parameters @@ -385,7 +499,7 @@ def make_relay_matmul_pattern(dim=2, optimized=False): Returns ------- - pattern: DFPattern + pattern: tvm.relay.dataflow_pattern.DFPattern The resulting pattern describing a matmul operation """ @@ -409,7 +523,14 @@ def make_relay_matmul_pattern(dim=2, optimized=False): raise Exception("matmul pattern only support dim 2 and 3") def _check_relay_matmul(call: tvm.relay.Expr) -> bool: - """Check if matmul pattern is correct.""" + """Check if matmul pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. + """ + last_call = call.args[0] if call.op.name == "squeeze" else call if last_call.op.name == "nn.dense": trans_b = last_call.args[1] @@ -419,12 +540,12 @@ def _check_relay_matmul(call: tvm.relay.Expr) -> bool: return trans_b.attrs["axes"] is None or list(trans_b.attrs["axes"]) == [1, 0] return True - def make_relay_embedding_pattern(optimized=False): + def make_relay_embedding_pattern(optimized=False) -> relay_pattern.DFPattern: """A simple utility to create patterns for 1d embedding. Returns ------- - pattern: DFPattern + pattern: tvm.relay.dataflow_pattern.DFPattern The resulting pattern describing a embedding operation """ @@ -434,7 +555,13 @@ def make_relay_embedding_pattern(optimized=False): return relay_pattern.is_op("take")(weight, astype) def _check_relay_embedding(call) -> bool: - """Check if embedding pattern is correct.""" + """Check if embedding pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. + """ weight = call.args[0] cast = call.args[1] @@ -444,12 +571,12 @@ def _check_relay_embedding(call) -> bool: and weight.checked_type.dtype == "float32" ) - def make_relay_gelu_pattern(optimized=False): + def make_relay_gelu_pattern(optimized=False) -> relay_pattern.DFPattern: """A simple utility to create patterns for gelu. Returns ------- - pattern: DFPattern + pattern: tvm.relay.dataflow_pattern.DFPattern The resulting pattern describing a gelu operation. """ @@ -464,7 +591,14 @@ def make_relay_gelu_pattern(optimized=False): return relay_pattern.is_op("multiply")(data, add) def _check_relay_gelu(call) -> bool: - """Check if gelu pattern is correct.""" + """Check if gelu pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. 
+ """ + return True return [ From e272a85e4e126954429a29c499a2ce7249221c93 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Wed, 9 Aug 2023 06:22:10 +0800 Subject: [PATCH 4/8] add annotation --- .../tvm/contrib/msc/core/transform/pattern.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/python/tvm/contrib/msc/core/transform/pattern.py b/python/tvm/contrib/msc/core/transform/pattern.py index 311889f2ec54..76e9651c603e 100644 --- a/python/tvm/contrib/msc/core/transform/pattern.py +++ b/python/tvm/contrib/msc/core/transform/pattern.py @@ -377,7 +377,9 @@ def pattern_table(): """Returns list of triples describing the name, dataflow pattern and predicate for all the MSC-supported operators.""" - def make_relay_conv_bias_pattern(op_name, optimized=False) -> relay_pattern.DFPattern: + def make_relay_conv_bias_pattern( + op_name: str, optimized: bool = False + ) -> relay_pattern.DFPattern: """A simple utility to create patterns for an operation fused with bias. Parameters @@ -419,7 +421,7 @@ def _check_relay_conv_bias(call: tvm.relay.Expr) -> bool: return True return False - def make_relay_linear_pattern(optimized=False) -> relay_pattern.DFPattern: + def make_relay_linear_pattern(optimized: bool = False) -> relay_pattern.DFPattern: """A simple utility to create patterns for linear. Parameters @@ -463,7 +465,7 @@ def _check_relay_linear(call: tvm.relay.Expr) -> bool: return True - def make_relay_linear_bias_pattern(optimized=False) -> relay_pattern.DFPattern: + def make_relay_linear_bias_pattern(optimized: bool = False) -> relay_pattern.DFPattern: """A simple utility to create patterns for linear_bias. Parameters @@ -489,7 +491,7 @@ def _check_relay_linear_bias(call: tvm.relay.Expr) -> bool: """Check if linear_bias pattern is correct.""" return True - def make_relay_matmul_pattern(dim=2, optimized=False) -> relay_pattern.DFPattern: + def make_relay_matmul_pattern(dim: int = 2, optimized: bool = False) -> relay_pattern.DFPattern: """A simple utility to create patterns for matmul. Parameters @@ -540,7 +542,7 @@ def _check_relay_matmul(call: tvm.relay.Expr) -> bool: return trans_b.attrs["axes"] is None or list(trans_b.attrs["axes"]) == [1, 0] return True - def make_relay_embedding_pattern(optimized=False) -> relay_pattern.DFPattern: + def make_relay_embedding_pattern(optimized: bool = False) -> relay_pattern.DFPattern: """A simple utility to create patterns for 1d embedding. Returns @@ -554,7 +556,7 @@ def make_relay_embedding_pattern(optimized=False) -> relay_pattern.DFPattern: astype = relay_pattern.is_op("cast")(data) return relay_pattern.is_op("take")(weight, astype) - def _check_relay_embedding(call) -> bool: + def _check_relay_embedding(call: tvm.relay.Expr) -> bool: """Check if embedding pattern is correct. Returns @@ -571,7 +573,7 @@ def _check_relay_embedding(call) -> bool: and weight.checked_type.dtype == "float32" ) - def make_relay_gelu_pattern(optimized=False) -> relay_pattern.DFPattern: + def make_relay_gelu_pattern(optimized: bool = False) -> relay_pattern.DFPattern: """A simple utility to create patterns for gelu. Returns @@ -590,7 +592,7 @@ def make_relay_gelu_pattern(optimized=False) -> relay_pattern.DFPattern: add = relay_pattern.is_op("add")(factor_3, mul_2) return relay_pattern.is_op("multiply")(data, add) - def _check_relay_gelu(call) -> bool: + def _check_relay_gelu(call: tvm.relay.Expr) -> bool: """Check if gelu pattern is correct. 
Returns From 78a8dd74093e0cc483bedddf2fa5e02502eefb95 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Fri, 11 Aug 2023 06:46:26 +0800 Subject: [PATCH 5/8] change test to unity --- ci/jenkins/unity_jenkinsfile.groovy | 1 + tests/scripts/task_config_build_gpu.sh | 1 + tests/scripts/unity/task_python_msc.sh | 35 ++++++++++++++++++++++++++ 3 files changed, 37 insertions(+) create mode 100644 tests/scripts/unity/task_python_msc.sh diff --git a/ci/jenkins/unity_jenkinsfile.groovy b/ci/jenkins/unity_jenkinsfile.groovy index 99485f7c557f..13ce378385bf 100644 --- a/ci/jenkins/unity_jenkinsfile.groovy +++ b/ci/jenkins/unity_jenkinsfile.groovy @@ -318,6 +318,7 @@ stage('Build and Test') { sh "${docker_run} ${ci_gpu} ./tests/scripts/task_config_build_gpu.sh build" make("${ci_gpu}", 'build', '-j2') sh "${docker_run} ${ci_gpu} ./tests/scripts/unity/task_python_relax_gpuonly.sh" + sh "${docker_run} ${ci_gpu} ./tests/scripts/unity/task_python_msc.sh" } } }, diff --git a/tests/scripts/task_config_build_gpu.sh b/tests/scripts/task_config_build_gpu.sh index 8929ae504168..37ab0a87f13d 100755 --- a/tests/scripts/task_config_build_gpu.sh +++ b/tests/scripts/task_config_build_gpu.sh @@ -53,3 +53,4 @@ echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake echo set\(USE_PIPELINE_EXECUTOR ON\) >> config.cmake echo set\(USE_CUTLASS ON\) >> config.cmake echo set\(USE_CMSISNN ON\) >> config.cmake +echo set\(USE_MSC ON\) >> config.cmake diff --git a/tests/scripts/unity/task_python_msc.sh b/tests/scripts/unity/task_python_msc.sh new file mode 100644 index 000000000000..28b072607645 --- /dev/null +++ b/tests/scripts/unity/task_python_msc.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -euxo pipefail + +source tests/scripts/setup-pytest-env.sh +# to avoid openblas threading error +export TVM_BIND_THREADS=0 +export OMP_NUM_THREADS=1 + +export TVM_TEST_TARGETS="llvm;cuda" + +find . -type f -path "*.pyc" | xargs rm -f + +# Rebuild cython +make cython3 + + +echo "Running relay MXNet frontend test..." +run_pytest cython python-contrib-msc tests/python/contrib/test_msc From a22f4a9b43cddf2737515b28a206fad6386418f6 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Fri, 11 Aug 2023 07:09:21 +0800 Subject: [PATCH 6/8] remove msg --- tests/scripts/unity/task_python_msc.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/scripts/unity/task_python_msc.sh b/tests/scripts/unity/task_python_msc.sh index 28b072607645..eed1f79f83fc 100644 --- a/tests/scripts/unity/task_python_msc.sh +++ b/tests/scripts/unity/task_python_msc.sh @@ -30,6 +30,4 @@ find . -type f -path "*.pyc" | xargs rm -f # Rebuild cython make cython3 - -echo "Running relay MXNet frontend test..." 
run_pytest cython python-contrib-msc tests/python/contrib/test_msc From 8c441d9d0d7004960ef3d0cb6c0d07a4a189043f Mon Sep 17 00:00:00 2001 From: Archermmt Date: Fri, 11 Aug 2023 20:07:02 +0800 Subject: [PATCH 7/8] minor fix --- .../msc/core/transform/set_expr_layout.cc | 28 ++++++++++--------- src/contrib/msc/core/utils.cc | 6 ++-- tests/scripts/task_config_build_cpu.sh | 1 + 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc index 981829b56809..5915bef9e138 100644 --- a/src/contrib/msc/core/transform/set_expr_layout.cc +++ b/src/contrib/msc/core/transform/set_expr_layout.cc @@ -35,7 +35,7 @@ namespace relax { using namespace tvm::contrib::msc; NLayout InferNLayout(const Expr& expr, const VarLayoutMap& var_layout_map) { - if (expr.as() && var_layout_map.count(Downcast(expr))) { + if (expr->IsInstance() && var_layout_map.count(Downcast(expr))) { return GetNLayout(var_layout_map, expr); } return LayoutUtils::GetNLayout(expr); @@ -83,10 +83,12 @@ std::tuple AccumulateMatch(const std::vector& in_shap } // append tailed 1s if (in_pos >= 0) { - while (in_pos < in_shape.size() - 1 && in_shape[in_pos + 1] == 1) { + int64_t in_size = static_cast(in_shape.size()); + int64_t out_size = static_cast(out_shape.size()); + while (in_pos < in_size - 1 && in_shape[in_pos + 1] == 1) { in_pos++; } - while (out_pos < out_shape.size() - 1 && out_shape[out_pos + 1] == 1) { + while (out_pos < out_size - 1 && out_shape[out_pos + 1] == 1) { out_pos++; } } @@ -115,7 +117,7 @@ std::vector InferReduceAxes(const Array& input_shape, if (in_pos == -1) { return std::vector(); } - for (size_t i = out_start; i < out_pos + 1; i++) { + for (size_t i = out_start; i < static_cast(out_pos) + 1; i++) { out_axes.push_back(i + 1); } start = in_pos + 1; @@ -225,7 +227,7 @@ InferLayoutOutput ForwardInferLayoutCommon(const Call& call, } std::vector output_layouts; const auto& sinfo = GetStructInfo(call); - if (sinfo.as()) { + if (sinfo->IsInstance()) { output_layouts.push_back(layout_hint); } else if (const auto* tuple_sinfo = sinfo.as()) { for (size_t i = 0; i < tuple_sinfo->fields.size(); i++) { @@ -955,19 +957,19 @@ class LayoutInfer : public ExprVisitor { void BackwardInfer() { for (size_t e_idx = ordered_exprs_.size(); e_idx > 0; e_idx--) { const Expr& expr = ordered_exprs_[e_idx - 1]; - if (const auto* t_node = expr.as()) { + if (expr->IsInstance()) { continue; } - if (const auto* t_node = expr.as()) { + if (expr->IsInstance()) { continue; } - if (!expr.as()) { + if (!expr->IsInstance()) { continue; } const Call& call = Downcast(expr); size_t infered_num = 0; for (const auto& arg : call->args) { - if (arg.as() && var_map_.count(Downcast(arg))) { + if (arg->IsInstance() && var_map_.count(Downcast(arg))) { if (LayoutUtils::LayoutInfered(var_map_[Downcast(arg)]) > 0) { infered_num++; } @@ -975,8 +977,8 @@ class LayoutInfer : public ExprVisitor { infered_num++; } } - if (call->args.size() == 0 || infered_num == call->args.size() || !call->op.as() || - LayoutUtils::HasUnknownDimTensor(call->args)) { + if (call->args.size() == 0 || infered_num == call->args.size() || + !call->op->IsInstance() || LayoutUtils::HasUnknownDimTensor(call->args)) { continue; } const OpNode* op_node = call->op.as(); @@ -1013,7 +1015,7 @@ class LayoutInfer : public ExprVisitor { void SetInputLayouts(const Array& input_layouts, const Call& call) { if (input_layouts.size() == call->args.size()) { for (size_t i = 0; i < input_layouts.size(); i++) { - 
if (call->args[i].as()) { + if (call->args[i]->IsInstance()) { const auto& var = Downcast(call->args[i]); var_layout_map_[var] = input_layouts[i]; if (var_map_.count(var)) { @@ -1058,7 +1060,7 @@ class LayoutInfer : public ExprVisitor { if (LayoutUtils::LayoutInfered(call)) { infer_outputs = false; } - if (call->args.size() == 0 || !call->op.as() || + if (call->args.size() == 0 || !call->op->IsInstance() || LayoutUtils::HasUnknownDimTensor(call->args)) { infer_outputs = false; } diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index 66bbf8cc9790..7ecff876f278 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -152,7 +152,7 @@ const String StringUtils::ToString(const runtime::ObjectRef& obj) { } bool StringUtils::CompareArrays(const Array& left, const Array& right, int size) { - if (left.size() == right.size() == 0) { + if (left.size() == right.size() && left.size() == 0) { return true; } if (size == -1 && left.size() != right.size()) { @@ -163,10 +163,10 @@ bool StringUtils::CompareArrays(const Array& left, const Array& } size = left.size(); ICHECK_GT(size, 0) << "Positive size should be given, get " << size; - if (size > left.size() || size > right.size()) { + if (size > static_cast(left.size()) || size > static_cast(right.size())) { return false; } - for (size_t i = 0; i < size; i++) { + for (size_t i = 0; i < static_cast(size); i++) { if (left[i] != right[i]) { return false; } diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh index 9eda0d74d41c..0d6c0e2cae46 100755 --- a/tests/scripts/task_config_build_cpu.sh +++ b/tests/scripts/task_config_build_cpu.sh @@ -57,3 +57,4 @@ echo set\(USE_CCACHE OFF\) >> config.cmake echo set\(USE_ETHOSU OFF\) >> config.cmake echo set\(USE_UMA ON\) >> config.cmake echo set\(SUMMARIZE ON\) >> config.cmake +echo set\(USE_MSC ON\) >> config.cmake From 7489231f7f05be93000da32b19334a633dbb4d87 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Sat, 12 Aug 2023 09:02:03 +0800 Subject: [PATCH 8/8] move test to task_python_relax --- ci/jenkins/unity_jenkinsfile.groovy | 1 - tests/scripts/unity/task_python_msc.sh | 33 ------------------------ tests/scripts/unity/task_python_relax.sh | 3 +++ 3 files changed, 3 insertions(+), 34 deletions(-) delete mode 100644 tests/scripts/unity/task_python_msc.sh diff --git a/ci/jenkins/unity_jenkinsfile.groovy b/ci/jenkins/unity_jenkinsfile.groovy index 13ce378385bf..99485f7c557f 100644 --- a/ci/jenkins/unity_jenkinsfile.groovy +++ b/ci/jenkins/unity_jenkinsfile.groovy @@ -318,7 +318,6 @@ stage('Build and Test') { sh "${docker_run} ${ci_gpu} ./tests/scripts/task_config_build_gpu.sh build" make("${ci_gpu}", 'build', '-j2') sh "${docker_run} ${ci_gpu} ./tests/scripts/unity/task_python_relax_gpuonly.sh" - sh "${docker_run} ${ci_gpu} ./tests/scripts/unity/task_python_msc.sh" } } }, diff --git a/tests/scripts/unity/task_python_msc.sh b/tests/scripts/unity/task_python_msc.sh deleted file mode 100644 index eed1f79f83fc..000000000000 --- a/tests/scripts/unity/task_python_msc.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/usr/bin/env bash -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -set -euxo pipefail - -source tests/scripts/setup-pytest-env.sh -# to avoid openblas threading error -export TVM_BIND_THREADS=0 -export OMP_NUM_THREADS=1 - -export TVM_TEST_TARGETS="llvm;cuda" - -find . -type f -path "*.pyc" | xargs rm -f - -# Rebuild cython -make cython3 - -run_pytest cython python-contrib-msc tests/python/contrib/test_msc diff --git a/tests/scripts/unity/task_python_relax.sh b/tests/scripts/unity/task_python_relax.sh index b6b70ab457ec..121ba1389ae5 100755 --- a/tests/scripts/unity/task_python_relax.sh +++ b/tests/scripts/unity/task_python_relax.sh @@ -36,3 +36,6 @@ TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/dlight # python3 ./apps/relax_examples/mlp.py # python3 ./apps/relax_examples/nn_module.py # python3 ./apps/relax_examples/resnet.py + +# Test for MSC +pytest tests/python/contrib/test_msc
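
Taken together, the series adds the MSC expression-naming and layout passes, their C++ utilities, and the CI wiring that runs the new tests under task_python_relax.sh. For readers who want to try the passes outside of CI, the following is a minimal usage sketch distilled from the tests added in this series. It assumes a TVM build with USE_MSC=ON; the pass and FFI names (SetExprName, SetExprLayout, SpanGetAttr, and the "name"/"layout" span keys) are the ones exercised by the tests above, while the NamePrinter helper and the num_layers choice are illustrative only.

# Sketch: attach MSC names to a Relay module, then read them back from spans.
from tvm.relay import testing
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.expr_functor import ExprVisitor

from tvm.contrib.msc.core import _ffi_api
from tvm.contrib.msc.core import transform as msc_transform

# Build a small Relay workload and bind its parameters as constants,
# mirroring test_relay() in test_transform_set_expr_name.py.
mod, params = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float32")
mod["main"] = bind_params_by_name(mod["main"], params)

# SetExprName(as_relax=False) writes a "name" attribute into each expression's span.
mod = msc_transform.SetExprName(as_relax=False)(mod)


class NamePrinter(ExprVisitor):
    """Collect the span name of every call expression (illustrative helper)."""

    def __init__(self):
        super().__init__()
        self.names = []

    def visit_call(self, expr):
        super().visit_call(expr)
        # SpanGetAttr is the FFI helper the new tests use to read span attributes.
        self.names.append(_ffi_api.SpanGetAttr(expr.span, "name"))


printer = NamePrinter()
printer.visit(mod["main"])
print(printer.names)

For Relax modules the flow is analogous: apply msc_transform.SetExprName() and msc_transform.SetExprLayout() to the IRModule, then read the "name" and "layout" span attributes from each binding value, as RelaxChecker does in the tests above.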