From ed3790e979b4f0afa86890d2e597783deb67b85e Mon Sep 17 00:00:00 2001 From: Archermmt Date: Tue, 23 Jan 2024 12:36:54 +0800 Subject: [PATCH] [Unity][MSC][M4.1] Add plugin && plugin_builder, enable build and test in different frameworks (#16397) * add plugin building * minor fix --- .../tvm/contrib/msc/core/codegen/codegen.py | 33 +- .../msc/framework/tvm/codegen/codegen.py | 30 +- python/tvm/contrib/msc/plugin/__init__.py | 19 + python/tvm/contrib/msc/plugin/_ffi_api.py | 21 + python/tvm/contrib/msc/plugin/build.py | 286 ++++ .../contrib/msc/plugin/codegen/__init__.py | 19 + .../tvm/contrib/msc/plugin/codegen/codegen.py | 319 +++++ .../tvm/contrib/msc/plugin/codegen/sources.py | 1157 +++++++++++++++++ python/tvm/contrib/msc/plugin/op/__init__.py | 17 + python/tvm/contrib/msc/plugin/op/_ffi_api.py | 21 + python/tvm/contrib/msc/plugin/register.py | 85 ++ python/tvm/contrib/msc/plugin/utils.py | 109 ++ src/contrib/msc/core/ir/plugin.cc | 327 +++++ src/contrib/msc/core/ir/plugin.h | 686 ++++++++++ src/contrib/msc/core/utils.cc | 41 + src/contrib/msc/core/utils.h | 28 + src/contrib/msc/plugin/base_codegen.h | 674 ++++++++++ src/contrib/msc/plugin/codegen_utils.h | 75 ++ src/contrib/msc/plugin/tensorrt_codegen.cc | 901 +++++++++++++ src/contrib/msc/plugin/tensorrt_codegen.h | 134 ++ src/contrib/msc/plugin/torch_codegen.cc | 510 ++++++++ src/contrib/msc/plugin/torch_codegen.h | 137 ++ src/contrib/msc/plugin/tvm_codegen.cc | 411 ++++++ src/contrib/msc/plugin/tvm_codegen.h | 124 ++ tests/python/contrib/test_msc/test_plugin.py | 309 +++++ 25 files changed, 6438 insertions(+), 35 deletions(-) create mode 100644 python/tvm/contrib/msc/plugin/__init__.py create mode 100644 python/tvm/contrib/msc/plugin/_ffi_api.py create mode 100644 python/tvm/contrib/msc/plugin/build.py create mode 100644 python/tvm/contrib/msc/plugin/codegen/__init__.py create mode 100644 python/tvm/contrib/msc/plugin/codegen/codegen.py create mode 100644 python/tvm/contrib/msc/plugin/codegen/sources.py 
create mode 100644 python/tvm/contrib/msc/plugin/op/__init__.py create mode 100644 python/tvm/contrib/msc/plugin/op/_ffi_api.py create mode 100644 python/tvm/contrib/msc/plugin/register.py create mode 100644 python/tvm/contrib/msc/plugin/utils.py create mode 100644 src/contrib/msc/core/ir/plugin.cc create mode 100644 src/contrib/msc/core/ir/plugin.h create mode 100644 src/contrib/msc/plugin/base_codegen.h create mode 100644 src/contrib/msc/plugin/codegen_utils.h create mode 100644 src/contrib/msc/plugin/tensorrt_codegen.cc create mode 100644 src/contrib/msc/plugin/tensorrt_codegen.h create mode 100644 src/contrib/msc/plugin/torch_codegen.cc create mode 100644 src/contrib/msc/plugin/torch_codegen.h create mode 100644 src/contrib/msc/plugin/tvm_codegen.cc create mode 100644 src/contrib/msc/plugin/tvm_codegen.h create mode 100644 tests/python/contrib/test_msc/test_plugin.py diff --git a/python/tvm/contrib/msc/core/codegen/codegen.py b/python/tvm/contrib/msc/core/codegen/codegen.py index e9013ba66153..8ffaf9dd5fa1 100644 --- a/python/tvm/contrib/msc/core/codegen/codegen.py +++ b/python/tvm/contrib/msc/core/codegen/codegen.py @@ -169,21 +169,18 @@ def relay_to_relax( ] # pylint: disable=unused-argument - def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule: - return BindParams("main", weights)(mod) - - mod = codegen.load(inputs, post_load=_bind_weights) - - mod = tvm.ir.transform.Sequential( - [ - # The canonicalization of relax variable bindings is not required - # for correctness. It does, however, remove trivial `x = y` - # bindings, preventing test cases from depending on their - # presence. 
- tvm.relax.transform.CanonicalizeBindings(), - tvm.relax.transform.ConvertToDataflow(min_size=1), - ], - name="tvm.contrib.msc.core.codegen.relay_to_relax_postproc", - )(mod) - - return mod + def _post_proc(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule: + mod = BindParams("main", weights)(mod) + return tvm.ir.transform.Sequential( + [ + # The canonicalization of relax variable bindings is not required + # for correctness. It does, however, remove trivial `x = y` + # bindings, preventing test cases from depending on their + # presence. + tvm.relax.transform.CanonicalizeBindings(), + tvm.relax.transform.ConvertToDataflow(min_size=1), + ], + name="tvm.contrib.msc.core.codegen.relay_to_relax_postproc", + )(mod) + + return codegen.load(inputs, post_load=_post_proc) diff --git a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py b/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py index 489ca0a2b528..c344b9260644 100644 --- a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py +++ b/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py @@ -65,24 +65,20 @@ def _save_weights(folder: msc_utils.MSCDirectory): f_params.write(tvm.runtime.save_param_dict(weights)) # pylint: disable=unused-argument - def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule: + def _post_proc(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule: if weights: mod = BindParams("main", weights)(mod) - return mod + return tvm.ir.transform.Sequential( + [ + # The canonicalization of relax variable bindings is not required + # for correctness. It does, however, remove trivial `x = y` + # bindings, preventing test cases from depending on their + # presence. 
+ tvm.relax.transform.CanonicalizeBindings(), + tvm.relax.transform.ConvertToDataflow(min_size=1), + ], + name="tvm.contrib.msc.framework.tvm.codegen.to_relax_postproc", + )(mod) codegen = CodeGen(graph, _ffi_api.GetRelaxSources, codegen_config, print_config, build_folder) - mod = codegen.load(inputs, pre_load=_save_weights, post_load=_bind_weights) - - mod = tvm.ir.transform.Sequential( - [ - # The canonicalization of relax variable bindings is not required - # for correctness. It does, however, remove trivial `x = y` - # bindings, preventing test cases from depending on their - # presence. - tvm.relax.transform.CanonicalizeBindings(), - tvm.relax.transform.ConvertToDataflow(min_size=1), - ], - name="tvm.contrib.msc.framework.tvm.codegen.to_relax_postproc", - )(mod) - - return mod + return codegen.load(inputs, pre_load=_save_weights, post_load=_post_proc) diff --git a/python/tvm/contrib/msc/plugin/__init__.py b/python/tvm/contrib/msc/plugin/__init__.py new file mode 100644 index 000000000000..53b4774db1b5 --- /dev/null +++ b/python/tvm/contrib/msc/plugin/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""tvm.contrib.msc.plugin""" + +from .build import * diff --git a/python/tvm/contrib/msc/plugin/_ffi_api.py b/python/tvm/contrib/msc/plugin/_ffi_api.py new file mode 100644 index 000000000000..0e12c29242d1 --- /dev/null +++ b/python/tvm/contrib/msc/plugin/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.plugin._ffi_api""" + +import tvm._ffi + +tvm._ffi._init_api("msc.plugin", __name__) diff --git a/python/tvm/contrib/msc/plugin/build.py b/python/tvm/contrib/msc/plugin/build.py new file mode 100644 index 000000000000..b7f3cee9fc0b --- /dev/null +++ b/python/tvm/contrib/msc/plugin/build.py @@ -0,0 +1,286 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.plugin.build""" + +import os +import sys +import subprocess + +from typing import List, Dict, Any, Optional +from tvm.contrib.msc.core import utils as msc_utils +from tvm.contrib.msc.plugin.codegen import get_codegen +from .register import register_plugin + + +def _build_plugins( + plugins: Dict[str, dict], + frameworks: List[str], + workspace: msc_utils.MSCDirectory = None, + codegen_config: Optional[Dict[str, str]] = None, + cpp_print_config: Optional[Dict[str, str]] = None, + py_print_config: Optional[Dict[str, str]] = None, + externs_dir: msc_utils.MSCDirectory = None, + on_debug: bool = False, +): + """Build the plugins + + Parameters + ---------- + plugins: dict + The plugins define. + frameworks: list + The frameworks for plugin. + workspace: MSCDirectory + The workspace folder. + codegen_config: dict + The config to generate code. + cpp_print_config: dict + The config to print cpp code. + py_print_config: dict + The config to print python code. + externs_dir: MSCDirectory + The extern sources folder. + on_debug: bool + Whether to debug the building. 
+ """ + + workspace = workspace or msc_utils.msc_dir("msc_plugin") + + # register the plugins + extern_sources, extern_libs, ops_info = {}, {}, {} + for name, plugin in plugins.items(): + sources, libs, info = register_plugin(name, plugin, externs_dir) + extern_sources.update(sources) + extern_libs.update(libs) + ops_info[name] = info + # build plugins for frameworks + codegens = {} + for framework in frameworks: + codegen = get_codegen( + framework, + workspace, + codegen_config, + cpp_print_config=cpp_print_config, + py_print_config=py_print_config, + extern_sources=extern_sources, + extern_libs=extern_libs, + on_debug=on_debug, + ) + if not codegen.libs_built(): + codegen.build_libs() + if codegen.need_manager and not codegen.manager_built(): + codegen.build_manager(ops_info) + codegens[framework] = codegen + return codegens + + +def build_plugins( + plugins: Dict[str, dict], + frameworks: List[str], + workspace: msc_utils.MSCDirectory = None, + codegen_config: Optional[Dict[str, str]] = None, + cpp_print_config: Optional[Dict[str, str]] = None, + py_print_config: Optional[Dict[str, str]] = None, + externs_dir: msc_utils.MSCDirectory = None, + on_debug: bool = False, +) -> Dict[str, Any]: + """Build the plugins and load plugin manager + + Parameters + ---------- + plugins: dict + The plugins define. + frameworks: list + The frameworks for plugin. + workspace: MSCDirectory + The workspace folder. + codegen_config: dict + The config to generate code. + cpp_print_config: dict + The config to print cpp code. + py_print_config: dict + The config to print python code. + externs_dir: MSCDirectory + The extern sources folder. + on_debug: bool + Whether to debug the building. + + Returns + ------- + managers: dict + The plugin managers. 
+ """ + + codegens = _build_plugins( + plugins, + frameworks, + workspace, + codegen_config=codegen_config, + cpp_print_config=cpp_print_config, + py_print_config=py_print_config, + externs_dir=externs_dir, + on_debug=on_debug, + ) + managers = {} + for name, codegen in codegens.items(): + manager_file = codegen.manager_folder.relpath("manager.py") + manager_cls = msc_utils.load_callable(manager_file + ":PluginManager") + managers[name] = manager_cls(codegen.output_folder.path) + return managers + + +def pack_plugins( + plugins: Dict[str, dict], + frameworks: List[str], + project_name: str = "msc_plugin", + codegen_config: Optional[Dict[str, str]] = None, + cpp_print_config: Optional[Dict[str, str]] = None, + py_print_config: Optional[Dict[str, str]] = None, + externs_dir: msc_utils.MSCDirectory = None, + setup_config: Optional[Dict[str, str]] = None, + on_debug: bool = False, +) -> str: + """Build the plugins and build to wheel + + Parameters + ---------- + plugins: dict + The plugins define. + frameworks: list + The frameworks for plugin. + project_name: str + The project name + codegen_config: dict + The config to generate code. + cpp_print_config: dict + The config to print cpp code. + py_print_config: dict + The config to print python code. + externs_dir: MSCDirectory + The extern sources folder. + setup_config: dict + The config to setup wheel. + on_debug: bool + Whether to debug the building. + + Returns + ------- + wheel_path: str + The file path of wheel. + """ + + project_dir = msc_utils.msc_dir(project_name) + workspace = project_dir.create_dir(project_name) + codegens = _build_plugins( + plugins, + frameworks, + workspace, + codegen_config=codegen_config, + cpp_print_config=cpp_print_config, + py_print_config=py_print_config, + externs_dir=externs_dir, + on_debug=on_debug, + ) + # add init files + init_code = """# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .manager import * +""" + with open(workspace.relpath("__init__.py"), "w") as f: + f.write(init_code) + for name in codegens: + with open(workspace.create_dir(name).relpath("__init__.py"), "w") as f: + f.write(init_code) + + # add setup file + if setup_config: + setup_config_str = "\n " + "\n ".join( + ["{} = {},".format(k, v) for k, v in setup_config.items()] + ) + else: + setup_config_str = "" + setup_code = """ +import os +import shutil + +from setuptools import find_packages, setup +from setuptools.dist import Distribution + +project_name = "{0}" +data_files = [] +for framework in [{2}]: + for folder in ["lib", "include"]: + src_path = os.path.join(project_name, framework, folder) + data_files.append( + ( + os.path.join(project_name, framework, folder), + [os.path.join(src_path, f) for f in os.listdir(src_path)], + ), + ) + +class BinaryDistribution(Distribution): + def has_ext_modules(self): + return True + + def is_pure(self): + return False + +setup( + name="{0}"{1}, + packages=find_packages(), + distclass=BinaryDistribution, + data_files=data_files +) + +shutil.rmtree("build") +shutil.rmtree("{0}.egg-info") +""".format( + project_name, setup_config_str, ",".join(['"{}"'.format(f) for f in frameworks]) + ) + with open(project_dir.relpath("setup.py"), "w") as f: + 
f.write(setup_code) + + # build the wheel + with project_dir: + command = "{} setup.py bdist_wheel".format(sys.executable) + with open("build.log", "w") as log_f: + process = subprocess.Popen(command, stdout=log_f, stderr=log_f, shell=True) + process.wait() + assert ( + process.returncode == 0 + ), "Failed to build wheel under {}, check build.log for detail".format(os.getcwd()) + dist_dir = project_dir.create_dir("dist") + files = list(dist_dir.listdir()) + assert len(files) == 1 and files[0].endswith( + ".whl" + ), "Failed to build wheel, no .whl found @ " + str(dist_dir.path) + return dist_dir.relpath(files[0]) diff --git a/python/tvm/contrib/msc/plugin/codegen/__init__.py b/python/tvm/contrib/msc/plugin/codegen/__init__.py new file mode 100644 index 000000000000..fbc0b0fed8a0 --- /dev/null +++ b/python/tvm/contrib/msc/plugin/codegen/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""tvm.contrib.msc.plugin.codegen""" + +from .codegen import * diff --git a/python/tvm/contrib/msc/plugin/codegen/codegen.py b/python/tvm/contrib/msc/plugin/codegen/codegen.py new file mode 100644 index 000000000000..a8cad6c725bf --- /dev/null +++ b/python/tvm/contrib/msc/plugin/codegen/codegen.py @@ -0,0 +1,319 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.codegen.codegen""" + +import os +import subprocess +from typing import Dict, List, Optional + +import tvm +from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.plugin import _ffi_api +from tvm.contrib.msc.core import utils as msc_utils +from .sources import get_plugin_sources + + +class BasePluginCodeGen(object): + """Manager class to generate codes and build plugin + + Parameters + ---------- + workspace: MSCDirectory + The workspace folder. + codegen_config: dict + The config to generate code. + cpp_print_config: dict + The config to print cpp code. + py_print_config: dict + The config to print python code. + extern_sources: dict + The depend source files. + extern_libs: dict + The depend lib files. + on_debug: bool + Whether to debug the building. 
+ """ + + def __init__( + self, + workspace: msc_utils.MSCDirectory, + codegen_config: Optional[Dict[str, str]] = None, + cpp_print_config: Optional[Dict[str, str]] = None, + py_print_config: Optional[Dict[str, str]] = None, + extern_sources: Dict[str, str] = None, + extern_libs: Dict[str, str] = None, + on_debug: bool = False, + ): + self._codegen_config = msc_utils.copy_dict(codegen_config) + self._cpp_print_config = msc_utils.dump_dict(cpp_print_config) + self._py_print_config = msc_utils.dump_dict(py_print_config) + self._build_folder = workspace.create_dir( + "source_" + self.framework, keep_history=on_debug, cleanup=not on_debug + ) + self._output_folder = workspace.create_dir(self.framework) + self._extern_sources = extern_sources or {} + self._extern_libs = extern_libs or {} + self.setup() + + def setup(self): + """Set up the codegen""" + + self._lib_folder = self._output_folder.create_dir("lib") + self._manager_folder = self._output_folder + self._libs = [os.path.basename(l) for l in self._extern_libs.values()] + self._libs.extend([os.path.basename(l) for l in self._lib_folder.listdir()]) + self._project_name = "msc_{}_plugin".format(self.framework) + self._codegen_config.update( + { + "install_dir": self._output_folder.path, + "project_name": self._project_name, + "version": msc_utils.get_version(self.framework), + } + ) + + def libs_built(self) -> bool: + """Check if the libs are built + + Returns + ------- + libs_built: bool + Whether libs are built. + """ + + return any(self._project_name in f for f in self._lib_folder.listdir()) + + def build_libs(self) -> List[str]: + """Generate source and build the lib + + Returns + ------- + paths: list + The lib file paths. 
+ """ + + codegen_config = msc_utils.dump_dict(self._codegen_config) + sources = self.source_getter(codegen_config, self._cpp_print_config, "build") + with self._build_folder as folder: + # add depends + with folder.create_dir("src") as src_folder: + for name, file in self._extern_sources.items(): + src_folder.copy(file, name) + for name, source in get_plugin_sources().items(): + src_folder.add_file(name, source) + for name, source in sources.items(): + if name == "CMakeLists.txt": + folder.add_file(name, source) + else: + src_folder.add_file(name, source) + with folder.create_dir("build"): + command = "cmake ../ && make" + with open("codegen.log", "w") as log_f: + process = subprocess.Popen(command, stdout=log_f, stderr=log_f, shell=True) + process.wait() + assert ( + process.returncode == 0 + ), "Failed to build plugin under {}, check codegen.log for detail".format( + os.getcwd() + ) + self._libs.extend([os.path.basename(l) for l in self._lib_folder.listdir()]) + return self._lib_folder.listdir(as_abs=True) + + def manager_built(self) -> bool: + """Check if the manager are built + + Returns + ------- + manager_built: bool + Whether manager is built. + """ + + return os.path.isfile(self._manager_folder.relpath("manager.py")) + + def build_manager(self, ops_info: dict) -> List[str]: + """Generate manager source for plugin + + Parameters + ---------- + ops_info: dict + The info of ops. + + Returns + ------- + paths: list + The manager file paths. 
+ """ + + self._codegen_config["libs"] = self._libs + self._codegen_config["ops_info"] = {n: msc_utils.dump_dict(i) for n, i in ops_info.items()} + codegen_config = msc_utils.dump_dict(self._codegen_config) + sources = self.source_getter(codegen_config, self._py_print_config, "manager") + manager_files = [] + with self._manager_folder as folder: + for name, source in sources.items(): + manager_files.append(folder.add_file(name, source)) + return manager_files + + @property + def source_getter(self): + raise NotImplementedError("source_getter is not supported for Base codegen") + + @property + def need_manager(self): + return True + + @property + def framework(self): + return MSCFramework.MSC + + @property + def output_folder(self): + return self._output_folder + + @property + def lib_folder(self): + return self._lib_folder + + @property + def manager_folder(self): + return self._manager_folder + + +class TVMPluginCodegen(BasePluginCodeGen): + """Plugin codegen for tvm""" + + def setup(self): + """Set up the codegen""" + + super().setup() + tvm_root = os.path.dirname(os.path.dirname(tvm.__path__[0])) + self._codegen_config.update( + {"need_convert": False, "with_runtime": True, "tvm_root": tvm_root} + ) + + @property + def source_getter(self): + return _ffi_api.GetTVMPluginSources + + @property + def framework(self): + return MSCFramework.TVM + + +class TorchPluginCodegen(BasePluginCodeGen): + """Plugin codegen for torch""" + + def setup(self): + """Set up the codegen""" + # pylint: disable=import-outside-toplevel + import torch.utils + + super().setup() + self._codegen_config.update( + { + "need_convert": True, + "with_runtime": False, + "torch_prefix": torch.utils.cmake_prefix_path, + } + ) + + @property + def source_getter(self): + return _ffi_api.GetTorchPluginSources + + @property + def framework(self): + return MSCFramework.TORCH + + +class TensorRTPluginCodegen(BasePluginCodeGen): + """Plugin codegen for tensorrt""" + + def setup(self): + """Set up the 
codegen""" + # pylint: disable=import-outside-toplevel + from tvm.contrib.msc.framework.tensorrt import _ffi_api as _trt_api + + super().setup() + self._codegen_config.update( + { + "need_convert": False, + "with_runtime": False, + "tensorrt_root": _trt_api.GetTensorRTRoot(), + } + ) + + @property + def source_getter(self): + return _ffi_api.GetTensorRTPluginSources + + @property + def framework(self): + return MSCFramework.TENSORRT + + +def get_codegen( + framework: str, + workspace: msc_utils.MSCDirectory, + codegen_config: Optional[Dict[str, str]] = None, + cpp_print_config: Optional[Dict[str, str]] = None, + py_print_config: Optional[Dict[str, str]] = None, + extern_sources: Dict[str, str] = None, + extern_libs: Dict[str, str] = None, + on_debug: bool = False, +): + """Create codegen for framework + + Parameters + ---------- + framework: str + THe framework for the plugin. + workspace: MSCDirectory + The workspace folder. + codegen_config: dict + The config to generate code. + cpp_print_config: dict + The config to print cpp code. + py_print_config: dict + The config to print python code. + extern_sources: dict + The depend source files. + extern_libs: dict + The depend lib files. + on_debug: bool + Whether to debug the building. 
+ """ + + codegen_cls = None + if framework == MSCFramework.TVM: + codegen_cls = TVMPluginCodegen + elif framework == MSCFramework.TORCH: + codegen_cls = TorchPluginCodegen + elif framework == MSCFramework.TENSORRT: + codegen_cls = TensorRTPluginCodegen + else: + raise NotImplementedError( + "framework {} is not support for plugin codegen".format(framework) + ) + return codegen_cls( + workspace, + codegen_config=codegen_config, + cpp_print_config=cpp_print_config, + py_print_config=py_print_config, + extern_sources=extern_sources, + extern_libs=extern_libs, + on_debug=on_debug, + ) diff --git a/python/tvm/contrib/msc/plugin/codegen/sources.py b/python/tvm/contrib/msc/plugin/codegen/sources.py new file mode 100644 index 000000000000..1ea95a958f7a --- /dev/null +++ b/python/tvm/contrib/msc/plugin/codegen/sources.py @@ -0,0 +1,1157 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.plugin.codegen.sources""" + +from typing import Dict + + +def get_plugin_base_h_code() -> str: + """Create plugin base header file codes + + Returns + ------- + source: str + The plugin base header source. 
+ """ + + return """#ifndef TVM_CONTRIB_MSC_UTILS_PLUGIN_BASE_H_ +#define TVM_CONTRIB_MSC_UTILS_PLUGIN_BASE_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace contrib { +namespace msc { +namespace plugin { + +typedef enum { + kUINT8 = 0, + kINT8 = 1, + kINT16 = 2, + kINT32 = 3, + kINT64 = 4, + kFLOAT16 = 5, + kFLOAT32 = 6, + kFLOAT64 = 7, + kUNKNOWN = 8, +} MetaDataType; + +class MetaShape { + public: + MetaShape() { shape_.resize(0); } + + MetaShape(const std::vector& shape) { + for (auto d : shape) { + shape_.push_back(d); + } + } + + template + void SetShape(const std::vector& shape) { + for (auto d : shape) { + shape_.push_back(static_cast(d)); + } + } + + template + void SetDim(int index, T dim) { + int valid_index = index < 0 ? shape_.size() + index : index; + if (valid_index >= shape_.size()) { + std::string err = + std::to_string(index) + " out of dims size " + std::to_string(shape_.size()); + throw std::runtime_error(err); + } + shape_[valid_index] = dim; + } + + template + const std::vector GetShape() const { + std::vector shape; + for (auto d : shape_) { + shape.push_back(d); + } + return shape; + } + + inline int64_t DimAt(int index) const { + int valid_index = index < 0 ? 
shape_.size() + index : index; + if (valid_index >= shape_.size()) { + std::string err = + std::to_string(index) + " out of dims size " + std::to_string(shape_.size()); + throw std::runtime_error(err); + } + return shape_[valid_index]; + } + + inline size_t ndim() const { return shape_.size(); } + + inline const std::vector shape() const { return shape_; } + + inline size_t size() const { + size_t size = 1; + for (auto d : shape_) { + assert(d > 0 && "Can not compute static size with unknow dim"); + size *= d; + } + return size; + } + + inline int64_t operator[](int index) const { return DimAt(index); } + + friend std::ostream& operator<<(std::ostream& out, const MetaShape& shape) { + for (size_t i = 0; i < shape.ndim(); i++) { + out << shape.DimAt(i) << (1 < shape.ndim() ? "" : ","); + } + return out; + } + + private: + std::vector shape_; +}; + +class MetaLayoutAxis { + public: + MetaLayoutAxis(const char name, size_t factor = 0) : factor_(factor) { + name_ = (factor == 0 ? "" : std::to_string(factor)) + std::string(1, name); + } + + MetaLayoutAxis(const std::string& name) { + if (name.size() == 1) { + factor_ = 0; + name_ = name; + } else { + factor_ = std::stoi(name.substr(1)); + name_ = name.substr(0, 1); + } + } + + inline const std::string name() const { return name_; } + + inline size_t factor() const { return factor_; } + + private: + std::string name_; + size_t factor_; +}; + +class MetaLayout { + public: + MetaLayout() {} + + MetaLayout(const std::string& name) : name_(name) { + int factor = 0; + for (char c : name) { + if (c >= 'A' && c <= 'Z') { + assert(factor == 0 && "Upper layout axis do not accept factor"); + MetaLayoutAxis axis(c); + axes_.push_back(axis); + } else if (c >= 'a' && c <= 'z') { + assert(factor > 0 && "Lower layout axis should has factor"); + MetaLayoutAxis axis(c, factor); + axes_.push_back(axis); + factor = 0; + } else if (c >= '0' && c <= '9') { + assert(factor >= 0 && "Factor number should between 0 and 9"); + factor = factor * 
10 + c - '0'; + } else { + throw std::runtime_error("Unexpected layout axis " + name); + } + } + CheckValid(); + } + + MetaLayout(const std::vector& axes) : axes_(axes) { + name_ = ""; + for (auto a : axes_) { + name_ += (a.factor() == 0 ? "" : std::to_string(a.factor())) + a.name(); + } + CheckValid(); + }; + + void CheckValid() { + std::set recorded_axes; + for (auto a : axes_) { + auto axis_name = a.name(); + assert(!recorded_axes.count(axis_name) && ("Has duplicate layout axis in " + name_).c_str()); + recorded_axes.insert(axis_name); + } + } + + inline const MetaLayoutAxis AxisAt(int index) const { + int valid_index = index < 0 ? axes_.size() + index : index; + if (valid_index >= axes_.size()) { + std::string err = std::to_string(index) + " out of axes size " + std::to_string(axes_.size()); + throw std::runtime_error(err); + } + return axes_[valid_index]; + } + + inline MetaLayoutAxis operator[](int index) { return AxisAt(index); } + + inline size_t ndim() const { return axes_.size(); } + + inline std::string name() const { return name_; } + + friend std::ostream& operator<<(std::ostream& out, const MetaLayout& layout) { + out << layout.name(); + return out; + } + + private: + std::string name_; + std::vector axes_; +}; + +class MetaTensor { + public: + MetaTensor() {} + + MetaTensor(const MetaShape& shape, const MetaDataType& data_type, + const MetaLayout& layout = MetaLayout()) + : shape_(shape), data_type_(data_type), layout_(layout) {} + + inline const MetaShape shape() const { return shape_; } + + inline MetaDataType data_type() const { return data_type_; } + + inline const std::vector meta_shape() const { return shape_.shape(); } + + inline const MetaLayout layout() const { return layout_; } + + inline const std::string layout_name() const { return layout_.name(); } + + inline size_t ndim() const { return shape_.ndim(); } + + inline size_t size(bool count_batch = true) const { + if (count_batch) { + size_t batch_dim = 0; + for (size_t i = 0; i < 
layout_.ndim(); i++) { + if (layout_.AxisAt(i).name() == "N") { + batch_dim = i; + } + } + return shape_.size() / shape_.shape()[batch_dim]; + } + return shape_.size(); + } + + inline MetaLayoutAxis AxisAt(int index) const { return layout_.AxisAt(index); } + + inline int AxisOf(const std::string& axis) const { + for (size_t i = 0; i < layout_.ndim(); i++) { + if (layout_.AxisAt(i).name() == axis) { + return i; + } + } + return -1; + } + + inline int64_t DimAt(int index) const { return shape_.DimAt(index); } + + inline int64_t DimAt(const std::string& axis) const { + int idx = AxisOf(axis); + if (idx >= 0) { + return shape_.DimAt(idx); + } + throw std::runtime_error("Can not find dim for " + axis); + } + + friend std::ostream& operator<<(std::ostream& out, const MetaTensor& tensor) { + out << "tensor : <" << tensor.shape() << ">, (" << tensor.layout() << ")"; + return out; + } + + private: + MetaShape shape_; + MetaDataType data_type_; + MetaLayout layout_; +}; + +template +class DataTensor : public MetaTensor { + public: + DataTensor(const MetaShape shape, const MetaDataType& data_type, const MetaLayout layout, T* data) + : MetaTensor(shape, data_type, layout) { + data_ = data; + } + + DataTensor(const MetaShape shape, const MetaDataType& data_type, const MetaLayout layout, + const T* data) + : MetaTensor(shape, data_type, layout) { + data_ = const_cast(data); + } + + T* data() const { return data_; } + + const T* const_data() const { return data_; } + + private: + T* data_{nullptr}; +}; + +} // namespace plugin +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_MSC_UTILS_PLUGIN_BASE_H_ +""" + + +def _get_common_utils() -> str: + """Get the utils for common + + Returns + ------- + source: str + The plugin utils for common. 
+ """ + + return """class SerializeUtils { + public: + // Helper function for serializing plugin attrs + template + static const std::string ToString(const T& value) { + return std::to_string(value); + } + + static std::string ToString(const std::string& value) { return value; } + + template + static std::string ToString(const std::vector& value) { + std::string str = std::to_string(value.size()); + for (const auto& v : value) { + str += "," + std::to_string(v); + } + return str; + } + + static void FromString(const std::string& src, std::string& target) { target = src; } + + static void FromString(const std::string& src, bool& target) { + target = std::stoi(src) > 0 ? true : false; + } + + static void FromString(const std::string& src, int& target) { target = std::stoi(src); } + + static void FromString(const std::string& src, size_t& target) { target = std::stoi(src); } + + static void FromString(const std::string& src, long& target) { target = std::stol(src); } + + static void FromString(const std::string& src, float& target) { target = std::stod(src); } + + static void FromString(const std::string& src, double& target) { target = std::stof(src); } + + template + static void FromString(const std::string& src, std::vector& target) { + std::string left_str = src; + int pos = left_str.find(","); + if (pos == std::string::npos) { + return; + } + assert(pos > 0); + size_t src_size; + FromString(left_str.substr(0, pos), src_size); + target.resize(src_size); + for (size_t i = 0; i < src_size; i++) { + pos = left_str.find(","); + left_str = left_str.substr(pos + 1); + FromString(left_str, target[i]); + } + } + + static void FromString(const std::string& src, std::vector& target) { + std::vector values; + FromString(src, values); + target.resize(values.size()); + for (size_t i = 0; i < values.size(); i++) { + target[i] = values[i] > 0 ? 
true : false; + } + } +}; + +class DataUtils { + public: + static MetaDataType ToMetaType(const std::string& name) { + MetaDataType dtype; + if (name == "int8") { + dtype = MetaDataType::kINT8; + } else if (name == "uint8" || name == "char") { + dtype = MetaDataType::kUINT8; + } else if (name == "int16") { + dtype = MetaDataType::kINT16; + } else if (name == "int32" || name == "int") { + dtype = MetaDataType::kINT32; + } else if (name == "int64" || name == "long") { + dtype = MetaDataType::kINT64; + } else if (name == "float16" || name == "half") { + dtype = MetaDataType::kFLOAT16; + } else if (name == "float32" || name == "float") { + dtype = MetaDataType::kFLOAT32; + } else if (name == "float64" || name == "double") { + dtype = MetaDataType::kFLOAT64; + } else { + dtype = MetaDataType::kUNKNOWN; + } + return dtype; + } + + static bool IsListType(const std::string& dtype) { + int pos = dtype.find("list("); + return pos == 0; + } + + static const std::string GetEleType(const std::string& dtype) { + int pos = dtype.find("list("); + if (pos == 0) { + return dtype.substr(pos + 5, dtype.size() - 6); + } + return ""; + } +}; +""" + + +def _get_tvm_utils() -> str: + """Get the utils for tvm + + Returns + ------- + source: str + The plugin utils for tvm. 
+ """ + + return """ +#ifdef PLUGIN_SUPPORT_TVM +using namespace tvm::relax; +using namespace tvm::runtime; +class TVMUtils { + public: + static void AttrFromPrim(const PrimValue& expr, std::string& target) { + ICHECK(expr->IsInstance()) << "Expr is not StringImm"; + target = Downcast(expr)->value; + } + + static void AttrFromPrim(const PrimValue& expr, bool& target) { + ICHECK(expr->value->IsInstance()) << "Expr value is not IntImm"; + target = Downcast(expr->value)->value; + } + + static void AttrFromPrim(const PrimValue& expr, int& target) { + ICHECK(expr->value->IsInstance()) << "Expr value is not IntImm"; + target = Downcast(expr->value)->value; + } + + static void AttrFromPrim(const PrimValue& expr, size_t& target) { + ICHECK(expr->value->IsInstance()) << "Expr value is not IntImm"; + target = Downcast(expr->value)->value; + } + + static void AttrFromPrim(const PrimValue& expr, long& target) { + ICHECK(expr->value->IsInstance()) << "Expr value is not IntImm"; + target = Downcast(expr->value)->value; + } + + static void AttrFromPrim(const PrimValue& expr, float& target) { + ICHECK(expr->value->IsInstance()) << "Expr value is not FloatImm"; + target = Downcast(expr->value)->value; + } + + static void AttrFromPrim(const PrimValue& expr, double& target) { + ICHECK(expr->value->IsInstance()) << "Expr value is not FloatImm"; + target = Downcast(expr->value)->value; + } + + template + static void AttrFromPrims(const Tuple& tuple, std::vector& target) { + for (size_t i = 0; i < tuple->fields.size(); i++) { + ICHECK(tuple->fields[i]->IsInstance()) << "Field is not PrimValue"; + AttrFromPrim(Downcast(tuple->fields[i]), target[i]); + } + } + + static void AttrFromArg(const TVMArgValue& arg, std::string& target) { + target = arg.operator std::string(); + } + + static void AttrFromArg(const TVMArgValue& arg, bool& target) { target = arg; } + + static void AttrFromArg(const TVMArgValue& arg, int& target) { target = arg; } + + static void AttrFromArg(const TVMArgValue& arg, 
size_t& target) { target = int(arg); } + + static void AttrFromArg(const TVMArgValue& arg, long& target) { target = int64_t(arg); } + + static void AttrFromArg(const TVMArgValue& arg, float& target) { target = double(arg); } + + static void AttrFromArg(const TVMArgValue& arg, double& target) { target = arg; } + + template + static void AttrFromArgs(const TVMArgs& args, size_t start, size_t num, std::vector& target) { + for (size_t i = 0; i < num; i++) { + AttrFromArg(args[start + i], target[i]); + } + } + + static MetaDataType ToMetaType(const DataType& dtype) { + MetaDataType meta_type; + if (dtype.code() == 0 && dtype.bits() == 8) { + meta_type = MetaDataType::kINT8; + } else if (dtype.code() == 0 && dtype.bits() == 16) { + meta_type = MetaDataType::kINT16; + } else if (dtype.code() == 0 && dtype.bits() == 32) { + meta_type = MetaDataType::kINT32; + } else if (dtype.code() == 0 && dtype.bits() == 64) { + meta_type = MetaDataType::kINT64; + } else if (dtype.code() == 1 && dtype.bits() == 8) { + meta_type = MetaDataType::kUINT8; + } else if (dtype.code() == 2 && dtype.bits() == 16) { + meta_type = MetaDataType::kFLOAT16; + } else if (dtype.code() == 2 && dtype.bits() == 32) { + meta_type = MetaDataType::kFLOAT32; + } else if (dtype.code() == 2 && dtype.bits() == 64) { + meta_type = MetaDataType::kFLOAT64; + } else { + meta_type = MetaDataType::kUNKNOWN; + } + return meta_type; + } + + static MetaDataType ToMetaType(const DLDataType& dtype) { + MetaDataType meta_type; + if (dtype.code == 0U && dtype.bits == 8) { + meta_type = MetaDataType::kINT8; + } else if (dtype.code == 0U && dtype.bits == 16) { + meta_type = MetaDataType::kINT16; + } else if (dtype.code == 0U && dtype.bits == 32) { + meta_type = MetaDataType::kINT32; + } else if (dtype.code == 0U && dtype.bits == 64) { + meta_type = MetaDataType::kINT64; + } else if (dtype.code == 1U && dtype.bits == 8) { + meta_type = MetaDataType::kUINT8; + } else if (dtype.code == 2U && dtype.bits == 16) { + meta_type = 
MetaDataType::kFLOAT16; + } else if (dtype.code == 2U && dtype.bits == 32) { + meta_type = MetaDataType::kFLOAT32; + } else if (dtype.code == 2U && dtype.bits == 64) { + meta_type = MetaDataType::kFLOAT64; + } else { + meta_type = MetaDataType::kUNKNOWN; + } + return meta_type; + } + + static MetaShape ToMetaShape(const Optional>& tvm_shape) { + if (tvm_shape.defined()) { + std::vector shape_data; + for (auto s : tvm_shape.value()) { + if (s->IsInstance()) { + shape_data.push_back(Downcast(s)->value); + } else { + shape_data.push_back(-1); + } + } + return MetaShape(shape_data); + } + return MetaShape(); + } + + static MetaShape ToMetaShape(DLTensor* tensor, bool as_data = true) { + std::vector dims; + if (as_data) { + assert(tensor->ndim == 1); + assert(TVMUtils::ToMetaType(tensor->dtype) == MetaDataType::kINT64); + int64_t* data_ptr = (int64_t*)tensor->data; + for (size_t i = 0; i < tensor->shape[0]; i++) { + dims.push_back(data_ptr[i]); + } + } else { + for (size_t i = 0; i < tensor->ndim; i++) { + dims.push_back(tensor->shape[i]); + } + } + return MetaShape(dims); + } + + static MetaTensor ToMetaTensor(const Expr& expr, + const LayoutDecision& layout_dec = LayoutDecision()) { + const auto* sinfo = GetStructInfoAs(expr); + if (layout_dec.defined() && layout_dec->layout.defined()) { + const auto& layout = MetaLayout(layout_dec->layout.name()); + return MetaTensor(ToMetaShape(sinfo->GetShape()), ToMetaType(sinfo->dtype), layout); + } + const auto& layout = MetaLayout(SpanUtils::GetAttr(expr->span, "layout")); + return MetaTensor(ToMetaShape(sinfo->GetShape()), ToMetaType(sinfo->dtype), layout); + } + + template + static DataTensor ToDataTensor(DLTensor* tensor, bool read_only) { + if (read_only) { + return DataTensor(ToMetaShape(tensor, false), ToMetaType(tensor->dtype), MetaLayout(), + (const T*)(tensor->data)); + } else { + return DataTensor(ToMetaShape(tensor, false), ToMetaType(tensor->dtype), MetaLayout(), + (T*)(tensor->data)); + } + } + + static DataType 
ToTVMType(const MetaDataType& dtype) { + DataType tvm_type; + if (dtype == MetaDataType::kINT8) { + tvm_type = DataType::Int(8); + } else if (dtype == MetaDataType::kINT16) { + tvm_type = DataType::Int(16); + } else if (dtype == MetaDataType::kINT32) { + tvm_type = DataType::Int(32); + } else if (dtype == MetaDataType::kINT64) { + tvm_type = DataType::Int(64); + } else if (dtype == MetaDataType::kFLOAT16) { + tvm_type = DataType::Float(16); + } else if (dtype == MetaDataType::kFLOAT32) { + tvm_type = DataType::Float(32); + } else if (dtype == MetaDataType::kFLOAT64) { + tvm_type = DataType::Float(64); + } else { + throw std::runtime_error("Unsupported type"); + } + return tvm_type; + } + + static DataType ToTVMType(const std::string& dtype) { + return ToTVMType(DataUtils::ToMetaType(dtype)); + } + + static Array ToTVMShape(const MetaShape& meta_shape) { + Array tvm_shape; + for (size_t i = 0; i < meta_shape.ndim(); i++) { + auto dim = meta_shape.DimAt(i); + if (dim == -1) { + tvm_shape.push_back(tir::Any()); + } else { + tvm_shape.push_back(Integer(dim)); + } + } + return tvm_shape; + } + + static void FillDLShape(const MetaShape& shape, DLTensor* data) { + auto shape_data = static_cast(data->data); + for (size_t i = 0; i < shape.ndim(); i++) { + shape_data[i] = shape.DimAt(i); + } + } + + static TensorStructInfo ToTensorStructInfo(const MetaTensor& tensor, + const Optional& device) { + const auto& t_shape = ToTVMShape(tensor.shape()); + const auto& t_type = ToTVMType(tensor.data_type()); + return TensorStructInfo(ShapeExpr(t_shape), t_type, device); + } + + static TensorStructInfo ToTensorStructInfo(const MetaTensor& tensor, const Expr& expr) { + const auto* sinfo = GetStructInfoAs(expr); + return ToTensorStructInfo(tensor, sinfo->vdevice); + } + + static bool OnDevice(DLTensor* tensor, DLDeviceType device) { + return tensor->device.device_type == device; + } + + static void CheckDevice(DLTensor* tensor, DLDeviceType device) { + 
ICHECK_EQ(tensor->device.device_type, device); + } + + static Device DefaultCPU() { + Device cpu_dev{kDLCPU, 0}; + return cpu_dev; + } + + static Device DefaultCUDA() { + Device cuda_dev{kDLCUDA, 0}; + return cuda_dev; + } +}; +#endif // PLUGIN_SUPPORT_TVM +""" + + +def _get_torch_utils() -> str: + """Get the utils for torch + + Returns + ------- + source: str + The plugin utils for torch. + """ + + return """ +#ifdef PLUGIN_SUPPORT_TORCH +class TorchUtils { + public: + static MetaDataType ToMetaType(const torch::ScalarType& dtype) { + MetaDataType meta_type; + if (dtype == torch::kChar) { + meta_type = MetaDataType::kINT8; + } else if (dtype == torch::kInt) { + meta_type = MetaDataType::kINT32; + } else if (dtype == torch::kInt64) { + meta_type = MetaDataType::kINT64; + } else if (dtype == torch::kLong) { + meta_type = MetaDataType::kINT64; + } else if (dtype == torch::kFloat16) { + meta_type = MetaDataType::kFLOAT16; + } else if (dtype == torch::kFloat) { + meta_type = MetaDataType::kFLOAT32; + } else if (dtype == torch::kDouble) { + meta_type = MetaDataType::kFLOAT64; + } else { + meta_type = MetaDataType::kUNKNOWN; + } + return meta_type; + } + + static MetaShape ToMetaShape(const torch::Tensor& tensor) { + std::vector shape_data; + for (size_t idx = 0; idx < tensor.dim(); idx++) { + shape_data.push_back(tensor.size(idx)); + } + return MetaShape(shape_data); + } + + static MetaTensor ToMetaTensor(const torch::Tensor& tensor, + const MetaLayout& layout = MetaLayout()) { + return MetaTensor(ToMetaShape(tensor), ToMetaType(tensor.scalar_type()), layout); + } + + template + static DataTensor ToDataTensor(const torch::Tensor& tensor, const MetaTensor& meta, + bool read_only) { + if (read_only) { + return DataTensor(meta.shape(), meta.data_type(), meta.layout(), + (const T*)(tensor.data_ptr())); + } else { + return DataTensor(meta.shape(), meta.data_type(), meta.layout(), (T*)(tensor.data_ptr())); + } + } + + static torch::ScalarType ToTorchType(const MetaDataType& 
dtype) { + torch::ScalarType torch_type; + if (dtype == MetaDataType::kINT8) { + torch_type = torch::kChar; + } else if (dtype == MetaDataType::kINT32) { + torch_type = torch::kInt; + } else if (dtype == MetaDataType::kINT64) { + torch_type = torch::kInt64; + } else if (dtype == MetaDataType::kFLOAT16) { + torch_type = torch::kFloat16; + } else if (dtype == MetaDataType::kFLOAT32) { + torch_type = torch::kFloat; + } else if (dtype == MetaDataType::kFLOAT64) { + torch_type = torch::kDouble; + } else { + throw std::runtime_error("Unsupported type"); + } + return torch_type; + } + + static torch::ScalarType ToTorchType(const std::string& dtype) { + return ToTorchType(DataUtils::ToMetaType(dtype)); + } + + static torch::Device ToTorchDevice(const std::string& device) { + if (device == "cpu") { + return torch::Device(torch::kCPU); + } + if (device == "cuda") { + return torch::Device(torch::kCUDA); + } + return torch::Device(torch::kCPU); + } + + static torch::Tensor MallocTorchTensor(const MetaTensor& tensor, const torch::Device& device) { + auto t_type = ToTorchType(tensor.data_type()); + auto opt = torch::TensorOptions().dtype(t_type).device(device); + return torch::zeros(tensor.meta_shape(), opt); + } +}; +#endif // PLUGIN_SUPPORT_TORCH +""" + + +def _get_tensorrt_utils() -> str: + """Get the utils for tensorrt + + Returns + ------- + source: str + The plugin utils for tensorrt. 
+ """ + + return """ +#ifdef PLUGIN_SUPPORT_TENSORRT +using namespace nvinfer1; + +#ifndef TRT_VERSION_GE +#define TRT_VERSION_GE(major, minor, patch) \\ + ((TRT_MAJOR > major) || (TRT_MAJOR == major && TRT_MINOR > minor) || \\ + (TRT_MAJOR == major && TRT_MINOR == minor && TRT_PATCH >= patch)) +#endif + +class TRTUtils { + public: + template + static void ValToBuffer(char*& buffer, const T& val) { + *reinterpret_cast(buffer) = val; + buffer += sizeof(T); + } + + static void ValToBuffer(char*& buffer, const std::string& val) { + *reinterpret_cast(buffer) = val.size(); + buffer += sizeof(size_t); + val.copy(buffer, val.size()); + buffer += sizeof(char) * val.size(); + } + + template + static void ValToBuffer(char*& buffer, const std::vector& val) { + ValToBuffer(buffer, val.size()); + for (auto e : val) { + ValToBuffer(buffer, e); + } + } + + template + static void ValFromBuffer(const char*& buffer, T& val) { + val = *reinterpret_cast(buffer); + buffer += sizeof(T); + } + + static void ValFromBuffer(const char*& buffer, std::string& val) { + auto size = *reinterpret_cast(buffer); + buffer += sizeof(size_t); + val = std::string(reinterpret_cast(buffer), size); + buffer += sizeof(char) * size; + } + + template + static void ValFromBuffer(const char*& buffer, std::vector& val) { + size_t size; + ValFromBuffer(buffer, size); + val.resize(size); + for (size_t i = 0; i < size; i++) { + ValFromBuffer(buffer, val[i]); + } + } + + static PluginFieldType ToFieldType(const std::string& dtype) { + PluginFieldType field_type; + if (dtype == "char" || dtype == "uint8" || dtype == "string") { + field_type = PluginFieldType::kCHAR; + } else if (dtype == "int8") { + field_type = PluginFieldType::kINT8; + } else if (dtype == "int16") { + field_type = PluginFieldType::kINT16; + } else if (dtype == "int" || dtype == "int32") { + field_type = PluginFieldType::kINT32; + } else if (dtype == "float16" || dtype == "half") { + field_type = PluginFieldType::kFLOAT16; + } else if (dtype == 
"float32" || dtype == "float") { + field_type = PluginFieldType::kFLOAT32; + } else if (dtype == "float64" || dtype == "double") { + field_type = PluginFieldType::kFLOAT64; + } else { + field_type = PluginFieldType::kUNKNOWN; + } + return field_type; + } + + static const PluginField ToField(const std::string& name, const std::string& dtype) { + const auto& ele_type = DataUtils::GetEleType(dtype); + if (ele_type.size() == 0) { + return PluginField(name.c_str(), nullptr, ToFieldType(dtype), 1); + } + return PluginField(name.c_str(), nullptr, ToFieldType(ele_type), 11); + } + + static void FromField(const PluginField& field, std::string& val) { + assert(field.type == PluginFieldType::kCHAR); + const char* data = static_cast(field.data); + val = data; + } + + static void FromField(const PluginField& field, bool& val) { + assert(field.type == PluginFieldType::kINT32); + int int_val = *(static_cast(field.data)); + val = int_val == 0 ? false : true; + } + + static void FromField(const PluginField& field, int& val) { + assert(field.type == PluginFieldType::kINT32); + val = *(static_cast(field.data)); + } + + static void FromField(const PluginField& field, size_t& val) { + assert(field.type == PluginFieldType::kINT32); + val = *(static_cast(field.data)); + } + + static void FromField(const PluginField& field, long& val) { + assert(field.type == PluginFieldType::kINT32); + val = *(static_cast(field.data)); + } + + static void FromField(const PluginField& field, float& val) { + assert(field.type == PluginFieldType::kFLOAT32); + val = *(static_cast(field.data)); + } + + static void FromField(const PluginField& field, double& val) { + assert(field.type == PluginFieldType::kFLOAT64); + val = *(static_cast(field.data)); + } + + static MetaDataType ToMetaType(const DataType& dtype) { + MetaDataType meta_type; + if (dtype == DataType::kINT8) { + meta_type = MetaDataType::kINT8; + } else if (dtype == DataType::kINT32) { + meta_type = MetaDataType::kINT32; + } else if (dtype == 
DataType::kHALF) { + meta_type = MetaDataType::kFLOAT16; + } else if (dtype == DataType::kFLOAT) { + meta_type = MetaDataType::kFLOAT32; + } else { + meta_type = MetaDataType::kUNKNOWN; + } + return meta_type; + } + + static MetaShape ToMetaShape(const Dims& trt_dims, bool dynamic = false) { + std::vector dims; + if (!dynamic) { + dims.push_back(1); + } + for (size_t idx = 0; idx < trt_dims.nbDims; idx++) { + dims.push_back(trt_dims.d[idx]); + } + return MetaShape(dims); + } + + static MetaShape ToMetaShape(const DimsExprs& trt_dims) { + std::vector dims; + for (size_t idx = 0; idx < trt_dims.nbDims; idx++) { + assert(trt_dims.d[idx]->isConstant()); + dims.push_back(trt_dims.d[idx]->getConstantValue()); + } + return MetaShape(dims); + } + + static MetaShape ToMetaShape(const PluginTensorDesc& desc) { + return ToMetaShape(desc.dims, true); + } + + static MetaShape ToMetaShape(const DynamicPluginTensorDesc& desc) { + return ToMetaShape(desc.desc); + } + + static MetaTensor ToMetaTensor(const Dims& dims, const DataType& dtype, const std::string& layout, + bool dynamic = false) { + return MetaTensor(ToMetaShape(dims, dynamic), ToMetaType(dtype), MetaLayout(layout)); + } + + static MetaTensor ToMetaTensor(const DimsExprs& dims, const DataType& dtype, + const std::string& layout) { + return MetaTensor(ToMetaShape(dims), ToMetaType(dtype), MetaLayout(layout)); + } + + static MetaTensor ToMetaTensor(const PluginTensorDesc& desc, const std::string& layout) { + return ToMetaTensor(desc.dims, desc.type, layout, true); + } + + static MetaTensor ToMetaTensor(const DynamicPluginTensorDesc& desc, const std::string& layout) { + return ToMetaTensor(desc.desc, layout); + } + + static DataType ToDataType(const MetaDataType& dtype) { + DataType data_type; + if (dtype == MetaDataType::kINT8) { + data_type = DataType::kINT8; + } else if (dtype == MetaDataType::kINT32) { + data_type = DataType::kINT32; + } else if (dtype == MetaDataType::kFLOAT16) { + data_type = DataType::kHALF; + } 
else if (dtype == MetaDataType::kFLOAT32) { + data_type = DataType::kFLOAT; + } else { + data_type = DataType::kFLOAT; + } + return data_type; + } + + static DataType ToDataType(const std::string& dtype) { + return ToDataType(DataUtils::ToMetaType(dtype)); + } + + static Dims ToDims(const MetaShape& meta_shape, bool dynamic = false) { + std::vector int_dims; + if (dynamic) { + int_dims.push_back(meta_shape.DimAt(0)); + } + for (size_t i = 1; i < meta_shape.ndim(); i++) { + int_dims.push_back(meta_shape.DimAt(i)); + } + Dims dims{int(int_dims.size())}; + for (size_t i = 0; i < int_dims.size(); i++) { + dims.d[i] = int_dims[i]; + } + return dims; + } + + static DimsExprs ToDimsExprs(const MetaShape& meta_shape, IExprBuilder& builder) { + std::vector int_dims; + for (size_t i = 0; i < meta_shape.ndim(); i++) { + int_dims.push_back(meta_shape.DimAt(i)); + } + DimsExprs dims{int(int_dims.size())}; + for (size_t i = 0; i < int_dims.size(); i++) { + dims.d[i] = builder.constant(int_dims[i]); + } + return dims; + } + + static const MetaShape SetBatch(const MetaTensor& tensor, int batch_size) { + MetaShape shape = tensor.shape(); + int batch = tensor.AxisOf("N"); + if (batch < 0) { + batch = 0; + } + shape.SetDim(batch, batch_size); + return shape; + } + + template + static DataTensor ToDataTensor(const MetaTensor& tensor, int batch_size, const void* data) { + const auto& shape = SetBatch(tensor, batch_size); + return DataTensor(shape, tensor.data_type(), tensor.layout(), (const T*)(data)); + } + + template + static DataTensor ToDataTensor(const MetaTensor& tensor, int batch_size, void* data) { + const auto& shape = SetBatch(tensor, batch_size); + return DataTensor(shape, tensor.data_type(), tensor.layout(), (const T*)(data)); + } + + template + static DataTensor ToDataTensor(const MetaTensor& tensor, const PluginTensorDesc& desc, + const void* data) { + return DataTensor(ToMetaShape(desc), ToMetaType(desc.type), tensor.layout(), + (const T*)(data)); + } + + template + 
static DataTensor ToDataTensor(const MetaTensor& tensor, const PluginTensorDesc& desc, + void* data) { + return DataTensor(ToMetaShape(desc), ToMetaType(desc.type), tensor.layout(), (T*)(data)); + } +}; +#endif // PLUGIN_SUPPORT_TENSORRT +""" + + +def get_plugin_utils_h_code() -> str: + """Create plugin utils header file codes + + Returns + ------- + source: str + The plugin utils header source. + """ + + code = """#ifndef TVM_CONTRIB_MSC_UTILS_PLUGIN_UTILS_H_ +#define TVM_CONTRIB_MSC_UTILS_PLUGIN_UTILS_H_ + +#include +#include + +#include +#include +#include +#include +#include + +#include "plugin_base.h" + +#ifdef PLUGIN_ENABLE_CUDA +#include +#include +#endif // PLUGIN_ENABLE_CUDA + +#ifdef PLUGIN_SUPPORT_TVM +#include + +#include "tvm/../../src/contrib/msc/core/transform/layout_utils.h" +#include "tvm/../../src/contrib/msc/core/utils.h" +#ifdef PLUGIN_ENABLE_CUDA +#include "tvm/../../src/runtime/cuda/cuda_common.h" +#endif // PLUGIN_ENABLE_CUDA +#endif // PLUGIN_SUPPORT_TVM + +#ifdef PLUGIN_SUPPORT_TORCH +#include +#include +#ifdef PLUGIN_ENABLE_CUDA +#include +#endif // PLUGIN_ENABLE_CUDA +#endif // PLUGIN_SUPPORT_TORCH + +#ifdef PLUGIN_SUPPORT_TENSORRT +#include "NvInfer.h" +#endif // PLUGIN_SUPPORT_TENSORRT + +namespace tvm { +namespace contrib { +namespace msc { +namespace plugin { + +""" + code += _get_common_utils() + code += _get_tvm_utils() + code += _get_torch_utils() + code += _get_tensorrt_utils() + code += """ +} // namespace plugin +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_MSC_UTILS_PLUGIN_UTILS_H_ +""" + return code + + +def get_plugin_sources() -> Dict[str, str]: + """Create base sources for plugin codegen + + Returns + ------- + sources: dict + The base utils sources. 
+ """ + + return {"plugin_base.h": get_plugin_base_h_code(), "plugin_utils.h": get_plugin_utils_h_code()} diff --git a/python/tvm/contrib/msc/plugin/op/__init__.py b/python/tvm/contrib/msc/plugin/op/__init__.py new file mode 100644 index 000000000000..6b306c8c1f5b --- /dev/null +++ b/python/tvm/contrib/msc/plugin/op/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.plugin.op""" diff --git a/python/tvm/contrib/msc/plugin/op/_ffi_api.py b/python/tvm/contrib/msc/plugin/op/_ffi_api.py new file mode 100644 index 000000000000..2111e11227a1 --- /dev/null +++ b/python/tvm/contrib/msc/plugin/op/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.plugin.op._ffi_api""" + +import tvm._ffi + +tvm._ffi._init_api("msc.plugin.op", __name__) diff --git a/python/tvm/contrib/msc/plugin/register.py b/python/tvm/contrib/msc/plugin/register.py new file mode 100644 index 000000000000..2b1ec19d6ecf --- /dev/null +++ b/python/tvm/contrib/msc/plugin/register.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.plugin.register""" + +import os +from typing import Dict + +import tvm +from tvm.contrib.msc.core import utils as msc_utils +from tvm.contrib.msc.core import _ffi_api + + +def register_plugin( + name: str, plugin: dict, externs_dir: msc_utils.MSCDirectory = None +) -> Dict[str, str]: + """Register a plugin + + Parameters + ---------- + name: str + The name of the plugin. + plugin: dict + The define of a plugin. 
+ externs_dir: MSCDirectory + The extern sources folder. + + Returns + ------- + depend_files: dict + The depend file paths. + """ + + plugin = {"name": name, **msc_utils.load_dict(plugin)} + assert "externs" in plugin, "externs are needed to build plugin" + # check device compute + remove_externs = set() + for extern in plugin["externs"]: + if extern == "cuda_compute" and not tvm.cuda().exist: + remove_externs.add(extern) + if remove_externs: + plugin["externs"] = {k: v for k, v in plugin["externs"].items() if k not in remove_externs} + externs = plugin["externs"] + + def _check_file(info: dict, key: str) -> str: + if key not in info: + return None + file_path = info[key] + if os.path.abspath(file_path) != file_path: + assert externs_dir, "externs_dir is need to find file " + str(file_path) + file_path = externs_dir.relpath(file_path) + assert os.path.isfile(file_path), "Can not find externs file " + str(file_path) + info[key] = os.path.basename(file_path) + return file_path + + # find depend files + extern_sources, extern_libs = {}, {} + for info in externs.values(): + for key in ["header", "source"]: + file_path = _check_file(info, key) + if file_path: + extern_sources[os.path.basename(file_path)] = file_path + file_path = _check_file(info, "lib") + if file_path: + extern_libs[os.path.basename(file_path)] = file_path + _ffi_api.RegisterPlugin(name, msc_utils.dump_dict(plugin)) + # remove needless keys + for key in ["support_dtypes", "externs"]: + plugin.pop(key) + plugin["inputs"] = [{"name": i["name"]} for i in plugin["inputs"]] + plugin["outputs"] = [{"name": o["name"]} for o in plugin["outputs"]] + return extern_sources, extern_libs, plugin diff --git a/python/tvm/contrib/msc/plugin/utils.py b/python/tvm/contrib/msc/plugin/utils.py new file mode 100644 index 000000000000..770ad44b12ec --- /dev/null +++ b/python/tvm/contrib/msc/plugin/utils.py @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license 
agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.plugin.utils""" + +import os +from typing import Any + +from tvm import relax +from tvm import tir +from tvm.contrib.msc.core import utils as msc_utils + + +def to_expr(value: Any) -> relax.Expr: + """Change value to expr + + Parameters + ---------- + value: + The value with python type. + + Returns + ------- + expr: relax.Expr + The relax Expr. + """ + + if isinstance(value, (bool, int)): + value = tir.IntImm("int64", value) + expr = relax.PrimValue(value) + elif isinstance(value, float): + value = tir.FloatImm("float64", value) + expr = relax.PrimValue(value) + elif isinstance(value, str): + expr = relax.StringImm(value) + elif isinstance(value, (list, tuple)): + expr = relax.Tuple([to_expr(v) for v in value]) + else: + raise TypeError(f"Unsupported input type: {type(value)}") + return expr + + +def export_plugins(plugins: dict, folder: msc_utils.MSCDirectory) -> dict: + """Export the plugins + + Parameters + ---------- + plugins: dict + The plugins. + folder: MSCDirectory + The export folder. + + Returns + ------- + info: dict + The loadable plugins info. 
+ """ + + if not plugins: + return {} + info = {} + for name, plugin in plugins.items(): + with folder.create_dir(name) as sub_folder: + info[name] = sub_folder.path + plugin.export(info[name]) + return info + + +def load_plugins(info: dict) -> dict: + """Load the plugins + + Parameters + ---------- + info: dict + The plugins info. + + Returns + ------- + plugins: dict + The plugins. + """ + + if not info: + return {} + plugins = {} + for name, plugin in info.items(): + if isinstance(plugin, str): + manager_file = os.path.join(plugin, "manager.py") + assert os.path.isfile(manager_file), "Can not find manager file for plugin: " + str( + manager_file + ) + manager_cls = msc_utils.load_callable(manager_file + ":PluginManager") + plugins[name] = manager_cls(plugin) + else: + plugins[name] = plugin + return plugins diff --git a/src/contrib/msc/core/ir/plugin.cc b/src/contrib/msc/core/ir/plugin.cc new file mode 100644 index 000000000000..d34972639a7b --- /dev/null +++ b/src/contrib/msc/core/ir/plugin.cc @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/contrib/msc/core/ir/plugin.cc + */ + +#include "plugin.h" + +#include +#include +#include +#include +#include + +namespace tvm { +namespace contrib { +namespace msc { + +PluginAttr::PluginAttr(const String& name, const String& type, const String& default_value, + const String& describe) { + ObjectPtr n = make_object(); + n->name = std::move(name); + n->type = std::move(type); + n->default_value = std::move(default_value); + n->describe = std::move(describe); + data_ = std::move(n); +} + +PluginAttr::PluginAttr(const JsonPluginAttr& j_attr) { + ObjectPtr n = make_object(); + n->FromJson(j_attr); + data_ = std::move(n); +} + +PluginAttr::PluginAttr(const std::string& json_str) { + ObjectPtr n = make_object(); + n->FromJson(json_str); + data_ = std::move(n); +} + +const JsonPluginAttr PluginAttrNode::ToJson() const { + JsonPluginAttr j_attr; + j_attr.name = name; + j_attr.type = type; + j_attr.default_value = default_value; + j_attr.describe = describe; + return j_attr; +} + +void PluginAttrNode::FromJson(const JsonPluginAttr& j_attr) { + name = j_attr.name; + type = j_attr.type; + default_value = j_attr.default_value; + describe = j_attr.describe; +} + +void PluginAttrNode::FromJson(const std::string& json_str) { + std::istringstream is(json_str); + dmlc::JSONReader reader(&is); + JsonPluginAttr j_attr; + reader.Read(&j_attr); + FromJson(j_attr); +} + +PluginTensor::PluginTensor(const String& name, const String& dtype, const Integer& ndim, + const String& device, const String& describe) { + ObjectPtr n = make_object(); + n->name = std::move(name); + n->dtype = std::move(dtype); + n->ndim = std::move(ndim); + n->device = std::move(device); + n->describe = std::move(describe); + data_ = std::move(n); +} + +PluginTensor::PluginTensor(const JsonPluginTensor& j_tensor) { + ObjectPtr n = make_object(); + n->FromJson(j_tensor); + data_ = std::move(n); +} + +PluginTensor::PluginTensor(const std::string& json_str) { + ObjectPtr n = make_object(); + 
n->FromJson(json_str); + data_ = std::move(n); +} + +const JsonPluginTensor PluginTensorNode::ToJson() const { + JsonPluginTensor j_tensor; + j_tensor.name = name; + j_tensor.dtype = dtype; + j_tensor.ndim = ndim->value; + j_tensor.device = device; + j_tensor.describe = describe; + return j_tensor; +} + +void PluginTensorNode::FromJson(const JsonPluginTensor& j_tensor) { + name = j_tensor.name; + dtype = j_tensor.dtype; + ndim = Integer(j_tensor.ndim); + device = j_tensor.device; + describe = j_tensor.describe; +} + +void PluginTensorNode::FromJson(const std::string& json_str) { + std::istringstream is(json_str); + dmlc::JSONReader reader(&is); + JsonPluginTensor j_tensor; + reader.Read(&j_tensor); + FromJson(j_tensor); +} + +PluginExtern::PluginExtern(const String& name, const String& header, const String& source, + const String& lib, const String& describe) { + ObjectPtr n = make_object(); + n->name = std::move(name); + n->header = std::move(header); + n->source = std::move(source); + n->lib = std::move(lib); + n->describe = std::move(describe); + data_ = std::move(n); +} + +PluginExtern::PluginExtern(const JsonPluginExtern& j_extern) { + ObjectPtr n = make_object(); + n->FromJson(j_extern); + data_ = std::move(n); +} + +PluginExtern::PluginExtern(const std::string& json_str) { + ObjectPtr n = make_object(); + n->FromJson(json_str); + data_ = std::move(n); +} + +const JsonPluginExtern PluginExternNode::ToJson() const { + JsonPluginExtern j_extern; + j_extern.name = name; + j_extern.header = header; + j_extern.source = source; + j_extern.lib = lib; + j_extern.describe = describe; + return j_extern; +} + +void PluginExternNode::FromJson(const JsonPluginExtern& j_extern) { + name = j_extern.name; + header = j_extern.header; + source = j_extern.source; + lib = j_extern.lib; + describe = j_extern.describe; +} + +void PluginExternNode::FromJson(const std::string& json_str) { + std::istringstream is(json_str); + dmlc::JSONReader reader(&is); + JsonPluginExtern j_extern; 
+ reader.Read(&j_extern);
+ FromJson(j_extern);
+}
+
+Plugin::Plugin(const String& name, const String& version, const String& describe,
+ const Array& attrs, const Array& inputs,
+ const Array& outputs, const Array& buffers,
+ const Map& externs,
+ const Map>& support_dtypes,
+ const Map& options) {
+ ObjectPtr n = make_object();
+ n->name = std::move(name);
+ n->version = std::move(version);
+ n->describe = std::move(describe);
+ n->attrs = std::move(attrs);
+ n->inputs = std::move(inputs);
+ n->outputs = std::move(outputs);
+ n->buffers = std::move(buffers);
+ n->externs = std::move(externs);
+ n->support_dtypes = std::move(support_dtypes);
+ n->options = std::move(options);
+ data_ = std::move(n);
+}
+
+Plugin::Plugin(const JsonPlugin& j_plugin) {
+ ObjectPtr n = make_object();
+ n->FromJson(j_plugin);
+ data_ = std::move(n);
+}
+
+Plugin::Plugin(const std::string& json_str) {
+ ObjectPtr n = make_object();
+ n->FromJson(json_str);
+ data_ = std::move(n);
+}
+
+const JsonPlugin PluginNode::ToJson() const {
+ JsonPlugin j_plugin;
+ j_plugin.name = name;
+ j_plugin.version = version;
+ j_plugin.describe = describe;
+ for (const auto& a : attrs) {
+ j_plugin.attrs.push_back(a->ToJson());
+ }
+ for (const auto& t : inputs) {
+ j_plugin.inputs.push_back(t->ToJson());
+ }
+ for (const auto& t : outputs) {
+ j_plugin.outputs.push_back(t->ToJson());
+ }
+ for (const auto& t : buffers) {
+ j_plugin.buffers.push_back(t->ToJson());
+ }
+ for (const auto& pair : externs) {
+ j_plugin.externs[pair.first] = pair.second->ToJson();
+ }
+ for (const auto& pair : support_dtypes) {
+ std::vector dtypes;
+ for (const auto& d : pair.second) {
+ dtypes.push_back(d);
+ }
+ j_plugin.support_dtypes[pair.first] = dtypes;
+ }
+ for (const auto& pair : options) {
+ j_plugin.options[pair.first] = pair.second;
+ }
+ return j_plugin;
+}
+
+void PluginNode::FromJson(const JsonPlugin& j_plugin) {
+ name = j_plugin.name;
+ version = j_plugin.version;
+ describe = j_plugin.describe;
+ for (const
auto& a : j_plugin.attrs) { + attrs.push_back(PluginAttr(a)); + } + for (const auto& t : j_plugin.inputs) { + inputs.push_back(PluginTensor(t)); + } + for (const auto& t : j_plugin.outputs) { + outputs.push_back(PluginTensor(t)); + } + for (const auto& t : j_plugin.buffers) { + buffers.push_back(PluginTensor(t)); + } + for (const auto& pair : j_plugin.externs) { + externs.Set(pair.first, PluginExtern(pair.second)); + } + for (const auto& pair : j_plugin.support_dtypes) { + Array dtypes; + for (const auto& d : pair.second) { + dtypes.push_back(d); + } + support_dtypes.Set(pair.first, dtypes); + } + for (const auto& pair : j_plugin.options) { + options.Set(pair.first, pair.second); + } +} + +void PluginNode::FromJson(const std::string& json_str) { + std::istringstream is(json_str); + dmlc::JSONReader reader(&is); + JsonPlugin j_plugin; + reader.Read(&j_plugin); + FromJson(j_plugin); +} + +int PluginNode::FindDtypeRefIdx(const PluginTensor& tensor) const { + for (size_t i = 0; i < inputs.size(); i++) { + if (inputs[i]->dtype == tensor->dtype) { + return i; + } + } + return -1; +} + +int PluginNode::FindDeviceRefIdx(const PluginTensor& tensor) const { + for (size_t i = 0; i < inputs.size(); i++) { + if (inputs[i]->device == tensor->device) { + return i; + } + } + return -1; +} + +const Array ListPluginNames() { return PluginRegistry::Global()->ListAllNames(); } + +const Plugin GetPlugin(const String& name) { return PluginRegistry::Global()->Get(name); } + +bool IsPlugin(const String& name) { return PluginRegistry::Global()->Registered(name); } + +TVM_REGISTER_GLOBAL("msc.core.RegisterPlugin") + .set_body_typed([](const String& name, const String& json_str) { + PluginRegistry::Global()->Register(name, json_str); + }); + +TVM_REGISTER_GLOBAL("msc.core.ListPluginNames").set_body_typed([]() -> Array { + return ListPluginNames(); +}); + +TVM_REGISTER_GLOBAL("msc.core.GetPlugin").set_body_typed([](const String& name) -> Plugin { + return GetPlugin(name); +}); + 
+TVM_REGISTER_GLOBAL("msc.core.IsPlugin").set_body_typed([](const String& name) -> Bool { + return Bool(IsPlugin(name)); +}); + +} // namespace msc +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/msc/core/ir/plugin.h b/src/contrib/msc/core/ir/plugin.h new file mode 100644 index 000000000000..dc6f3be68dc4 --- /dev/null +++ b/src/contrib/msc/core/ir/plugin.h @@ -0,0 +1,686 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/core/ir/plugin.h + * \brief Plugin describe for msc. + */ +#ifndef TVM_CONTRIB_MSC_CORE_IR_PLUGIN_H_ +#define TVM_CONTRIB_MSC_CORE_IR_PLUGIN_H_ + +#include +#include + +#include +#include +#include + +#include "../../../../node/attr_registry.h" +#include "../utils.h" + +namespace tvm { +namespace contrib { +namespace msc { + +/*! + * \brief Json serialize and deserialize for Plugin Attribute. 
+ */ +struct JsonPluginAttr { + std::string name; + std::string type; + std::string default_value; + std::string describe; + + void Save(dmlc::JSONWriter* writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("name", name); + writer->WriteObjectKeyValue("type", type); + writer->WriteObjectKeyValue("default_value", default_value); + writer->WriteObjectKeyValue("describe", describe); + writer->EndObject(); + } + + void Load(dmlc::JSONReader* reader) { + int bitmask = 0; + std::string key; + reader->BeginObject(); + while (reader->NextObjectItem(&key)) { + if (key == "name") { + reader->Read(&name); + bitmask |= 1; + } else if (key == "type") { + reader->Read(&type); + bitmask |= 2; + } else if (key == "default_value") { + reader->Read(&default_value); + } else if (key == "describe") { + reader->Read(&describe); + } + } + ICHECK_EQ(bitmask, 1 | 2) << "name and type should be given for plugin attr"; + if (describe.size() == 0) { + describe = "Plugin attribute " + name + "(" + type + ")"; + } + } +}; + +/*! + * \brief Json serialize and deserialize for Plugin Tensor. 
+ */ +struct JsonPluginTensor { + std::string name; + std::string dtype; + int64_t ndim{-1}; + std::string device{"default"}; + std::string describe; + + void Save(dmlc::JSONWriter* writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("name", name); + writer->WriteObjectKeyValue("dtype", dtype); + writer->WriteObjectKeyValue("ndim", ndim); + writer->WriteObjectKeyValue("device", device); + writer->WriteObjectKeyValue("describe", describe); + writer->EndObject(); + } + + void Load(dmlc::JSONReader* reader) { + int bitmask = 0; + std::string key; + reader->BeginObject(); + while (reader->NextObjectItem(&key)) { + if (key == "name") { + reader->Read(&name); + bitmask |= 1; + } else if (key == "dtype") { + reader->Read(&dtype); + } else if (key == "ndim") { + reader->Read(&ndim); + } else if (key == "device") { + reader->Read(&device); + } else if (key == "describe") { + reader->Read(&describe); + } + } + ICHECK_EQ(bitmask, 1) << "name should be given for plugin tensor"; + if (describe.size() == 0) { + describe = "Plugin tensor " + name + "(" + dtype + " on " + device + ")"; + } + } +}; + +/*! + * \brief Json serialize and deserialize for Plugin Extern. 
+ */ +struct JsonPluginExtern { + std::string name; + std::string header; + std::string source; + std::string lib; + std::string describe; + + void Save(dmlc::JSONWriter* writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("name", name); + writer->WriteObjectKeyValue("header", header); + writer->WriteObjectKeyValue("source", source); + writer->WriteObjectKeyValue("lib", lib); + writer->WriteObjectKeyValue("describe", describe); + writer->EndObject(); + } + + void Load(dmlc::JSONReader* reader) { + int bitmask = 0; + std::string key; + reader->BeginObject(); + while (reader->NextObjectItem(&key)) { + if (key == "name") { + reader->Read(&name); + bitmask |= 1; + } else if (key == "header") { + reader->Read(&header); + } else if (key == "source") { + reader->Read(&source); + } else if (key == "lib") { + reader->Read(&lib); + } else if (key == "describe") { + reader->Read(&describe); + } + } + ICHECK_EQ(bitmask, 1) << "name should be given for plugin extern"; + if (describe.size() == 0) { + describe = "Plugin function " + name + "(from " + header + ")"; + } + } +}; + +/*! + * \brief Json serialize and deserialize for Plugin. 
+ */ +struct JsonPlugin { + std::string name; + std::string version; + std::string describe; + std::vector attrs; + std::vector inputs; + std::vector outputs; + std::vector buffers; + std::unordered_map externs; + std::unordered_map> support_dtypes; + std::unordered_map options; + + void Save(dmlc::JSONWriter* writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("name", name); + writer->WriteObjectKeyValue("version", version); + writer->WriteObjectKeyValue("describe", describe); + writer->WriteObjectKeyValue("attrs", attrs); + writer->WriteObjectKeyValue("inputs", inputs); + writer->WriteObjectKeyValue("outputs", outputs); + writer->WriteObjectKeyValue("buffers", buffers); + writer->WriteObjectKeyValue("externs", externs); + writer->WriteObjectKeyValue("support_dtypes", support_dtypes); + writer->WriteObjectKeyValue("options", options); + writer->EndObject(); + } + + void Load(dmlc::JSONReader* reader) { + int bitmask = 0; + std::string key; + reader->BeginObject(); + while (reader->NextObjectItem(&key)) { + if (key == "name") { + reader->Read(&name); + bitmask |= 1; + } else if (key == "version") { + reader->Read(&version); + } else if (key == "describe") { + reader->Read(&describe); + } else if (key == "attrs") { + reader->Read(&attrs); + } else if (key == "inputs") { + reader->Read(&inputs); + bitmask |= 2; + } else if (key == "outputs") { + reader->Read(&outputs); + bitmask |= 4; + } else if (key == "buffers") { + reader->Read(&buffers); + } else if (key == "externs") { + reader->Read(&externs); + } else if (key == "support_dtypes") { + reader->Read(&support_dtypes); + } else if (key == "options") { + reader->Read(&options); + } + } + ICHECK_EQ(bitmask, 1 | 2 | 4) << "name, inputs and outputs should be given for plugin"; + if (externs.size() > 0) { + ICHECK(externs.count("infer_output")) << "infer_output should be given as extern"; + bool has_compute = false; + for (const auto& pair : externs) { + if (StringUtils::EndsWith(pair.first, 
"_compute")) { + has_compute = true; + } + } + ICHECK(has_compute) << "No compute function found, please check"; + } + if (describe.size() == 0) { + describe = "Plugin " + name + "(" + version + ")"; + } + } +}; + +/*! + * \brief Attribute in Plugin. + */ +class PluginAttrNode : public Object { + public: + /*! \brief The name of attribute. */ + String name; + /*! \brief The type of attribute. */ + String type; + /*! \brief The default_value of attribute. */ + String default_value; + /*! \brief The describe of attribute. */ + String describe; + + /*! \brief Export attribute to json. */ + const JsonPluginAttr ToJson() const; + /*! \brief Load attribute from json struct. */ + void FromJson(const JsonPluginAttr& j_attr); + /*! \brief Load attribute from json string. */ + void FromJson(const std::string& json_str); + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("type", &type); + v->Visit("default_value", &default_value); + v->Visit("describe", &describe); + } + + bool SEqualReduce(const PluginAttrNode* other, SEqualReducer equal) const { + return equal(name, other->name) && equal(type, other->type) && + equal(default_value, other->default_value) && equal(describe, other->describe); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name); + hash_reduce(type); + hash_reduce(default_value); + hash_reduce(describe); + } + + static constexpr const char* _type_key = "msc.core.PluginAttr"; + TVM_DECLARE_FINAL_OBJECT_INFO(PluginAttrNode, Object); +}; + +/*! + * \brief Managed reference to PluginAttrNode. + * \sa PluginAttrNode + */ +class PluginAttr : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param name The name of the attribute. + * \param type The type of the attribute. + * \param default_value The default_value of the attribute. + * \param describe The describe of the attribute. 
+ */ + TVM_DLL PluginAttr(const String& name, const String& type, const String& default_value, + const String& describe); + + /*! + * \brief The json constructor. + * \param j_attr The json describe of the attribute. + */ + TVM_DLL PluginAttr(const JsonPluginAttr& j_attr); + + /*! + * \brief The json constructor. + * \param json_str The json describe of the attribute. + */ + TVM_DLL PluginAttr(const std::string& json_str); + + TVM_DEFINE_OBJECT_REF_METHODS(PluginAttr, ObjectRef, PluginAttrNode); +}; + +/*! + * \brief Tensor in Plugin. + */ +class PluginTensorNode : public Object { + public: + /*! \brief The name of tensor. */ + String name; + /*! \brief The dtype of tensor. */ + String dtype; + /*! \brief The ndim of tensor. */ + Integer ndim; + /*! \brief The device of tensor. */ + String device; + /*! \brief The describe of tensor. */ + String describe; + + /*! \brief Export tensor to json. */ + const JsonPluginTensor ToJson() const; + /*! \brief Load tensor from json struct. */ + void FromJson(const JsonPluginTensor& j_attr); + /*! \brief Load tensor from json string. */ + void FromJson(const std::string& json_str); + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("dtype", &dtype); + v->Visit("ndim", &ndim); + v->Visit("device", &device); + v->Visit("describe", &describe); + } + + bool SEqualReduce(const PluginTensorNode* other, SEqualReducer equal) const { + return equal(name, other->name) && equal(dtype, other->dtype) && equal(ndim, other->ndim) && + equal(device, other->device) && equal(describe, other->describe); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name); + hash_reduce(dtype); + hash_reduce(ndim); + hash_reduce(device); + hash_reduce(describe); + } + + static constexpr const char* _type_key = "msc.core.PluginTensor"; + TVM_DECLARE_FINAL_OBJECT_INFO(PluginTensorNode, Object); +}; + +/*! + * \brief Managed reference to PluginTensorNode. 
+ * \sa PluginTensorNode + */ +class PluginTensor : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param name The name of the tensor. + * \param dtype The dtype of the tensor. + * \param ndim The ndim of the tensor. + * \param device The device of the tensor. + * \param describe The describe of the tensor. + */ + TVM_DLL PluginTensor(const String& name, const String& dtype, const Integer& ndim, + const String& device, const String& describe); + + /*! + * \brief The json constructor. + * \param j_tensor The json describe of the tensor. + */ + TVM_DLL PluginTensor(const JsonPluginTensor& j_tensor); + + /*! + * \brief The json constructor. + * \param json_str The json describe of the tensor. + */ + TVM_DLL PluginTensor(const std::string& json_str); + + TVM_DEFINE_OBJECT_REF_METHODS(PluginTensor, ObjectRef, PluginTensorNode); +}; + +/*! + * \brief Extern symbol in Plugin. + */ +class PluginExternNode : public Object { + public: + /*! \brief The name of extern. */ + String name; + /*! \brief The header of extern. */ + String header; + /*! \brief The source of extern. */ + String source; + /*! \brief The lib of extern. */ + String lib; + /*! \brief The describe of extern. */ + String describe; + + /*! \brief Export extern to json. */ + const JsonPluginExtern ToJson() const; + /*! \brief Load extern from json struct. */ + void FromJson(const JsonPluginExtern& j_attr); + /*! \brief Load extern from json string. 
*/ + void FromJson(const std::string& json_str); + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("header", &header); + v->Visit("source", &source); + v->Visit("lib", &lib); + v->Visit("describe", &describe); + } + + bool SEqualReduce(const PluginExternNode* other, SEqualReducer equal) const { + return equal(name, other->name) && equal(header, other->header) && + equal(source, other->source) && equal(lib, other->lib) && + equal(describe, other->describe); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name); + hash_reduce(header); + hash_reduce(source); + hash_reduce(lib); + hash_reduce(describe); + } + + static constexpr const char* _type_key = "msc.core.PluginExtern"; + TVM_DECLARE_FINAL_OBJECT_INFO(PluginExternNode, Object); +}; + +/*! + * \brief Managed reference to PluginExternNode. + * \sa PluginExternNode + */ +class PluginExtern : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param name The name of the extern. + * \param header The header of the extern. + * \param source The source of the extern. + * \param lib The lib of the extern. + * \param describe The describe of the extern. + */ + TVM_DLL PluginExtern(const String& name, const String& header, const String& source, + const String& lib, const String& describe); + + /*! + * \brief The json constructor. + * \param j_extern The json describe of the extern. + */ + TVM_DLL PluginExtern(const JsonPluginExtern& j_extern); + + /*! + * \brief The json constructor. + * \param json_str The json describe of the extern. + */ + TVM_DLL PluginExtern(const std::string& json_str); + + TVM_DEFINE_OBJECT_REF_METHODS(PluginExtern, ObjectRef, PluginExternNode); +}; + +/*! + * \brief The Plugin in MSC. + */ +class PluginNode : public Object { + public: + /*! \brief The name of plugin. */ + String name; + /*! \brief The version of plugin. */ + String version; + /*! \brief The describe of plugin. */ + String describe; + /*! 
\brief The attributes of plugin. */ + Array attrs; + /*! \brief The inputs of plugin. */ + Array inputs; + /*! \brief The outputs of plugin. */ + Array outputs; + /*! \brief The buffers of plugin. */ + Array buffers; + /*! \brief The externs of plugin. */ + Map externs; + /*! \brief The support_dtypes of plugin. */ + Map> support_dtypes; + /*! \brief The options of plugin. */ + Map options; + + /*! \brief Export plugin to json. */ + const JsonPlugin ToJson() const; + /*! \brief Load plugin from json struct. */ + void FromJson(const JsonPlugin& j_attr); + /*! \brief Load plugin from json string. */ + void FromJson(const std::string& json_str); + + /*! \brief Find input ref index for dtype. */ + int FindDtypeRefIdx(const PluginTensor& tensor) const; + /*! \brief Find input ref index for device. */ + int FindDeviceRefIdx(const PluginTensor& tensor) const; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("version", &version); + v->Visit("describe", &describe); + v->Visit("attrs", &attrs); + v->Visit("inputs", &inputs); + v->Visit("outputs", &outputs); + v->Visit("buffers", &buffers); + v->Visit("externs", &externs); + v->Visit("support_dtypes", &support_dtypes); + v->Visit("options", &options); + } + + bool SEqualReduce(const PluginNode* other, SEqualReducer equal) const { + return equal(name, other->name) && equal(version, other->version) && + equal(describe, other->describe) && equal(attrs, other->attrs) && + equal(inputs, other->inputs) && equal(outputs, other->outputs) && + equal(buffers, other->buffers) && equal(externs, other->externs) && + equal(support_dtypes, other->support_dtypes) && equal(options, other->options); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name); + hash_reduce(version); + hash_reduce(describe); + hash_reduce(attrs); + hash_reduce(inputs); + hash_reduce(outputs); + hash_reduce(buffers); + hash_reduce(externs); + hash_reduce(externs); + hash_reduce(support_dtypes); + 
hash_reduce(options); + } + + static constexpr const char* _type_key = "msc.core.Plugin"; + TVM_DECLARE_FINAL_OBJECT_INFO(PluginNode, Object); +}; + +/*! + * \brief Managed reference to PluginNode. + * \sa PluginNode + */ +class Plugin : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param name The name of the plugin. + * \param version The version of the plugin. + * \param describe The describe of the plugin. + * \param attrs The attrs of the plugin. + * \param inputs The inputs of the plugin. + * \param outputs The outputs of the plugin. + * \param buffers The buffers of the plugin. + * \param externs The externs of the plugin. + * \param support_dtypes The support_dtypes of the plugin. + * \param options The options of the plugin. + */ + TVM_DLL Plugin(const String& name, const String& version, const String& describe, + const Array& attrs, const Array& inputs, + const Array& outputs, const Array& buffers, + const Map& externs, + const Map>& support_dtypes, + const Map& options); + + /*! + * \brief The json constructor. + * \param j_plugin The json describe of the plugin. + */ + TVM_DLL Plugin(const JsonPlugin& j_plugin); + + /*! + * \brief The json constructor. + * \param json_str The json describe of the plugin. + */ + TVM_DLL Plugin(const std::string& json_str); + + TVM_DEFINE_OBJECT_REF_METHODS(Plugin, ObjectRef, PluginNode); +}; + +class PluginRegistry { + public: + /*! + * \brief Register a new plugin. + * \param name The name of the item. + * \param json_str The json_str. + * \return The corresponding entry. + */ + bool Register(const String& name, const String& json_str) { + plugin_map_[name] = Plugin(json_str); + return true; + } + + /*! + * \brief Check if an plugin is registered. + * \param name The name of the item. + * \return Whether the plugin is registered. + */ + bool Registered(const String& name) const { + auto it = plugin_map_.find(name); + return it != plugin_map_.end(); + } + + /*! 
+ * \brief Get an plugin from the registry. + * \param name The name of the item. + * \return The corresponding plugin. + */ + const Plugin Get(const String& name) const { + auto it = plugin_map_.find(name); + ICHECK(it != plugin_map_.end()) << "Can not find plugin " << name; + return it->second; + } + + /*! + * \brief List all the plugin names in the registry. + * \return The plugin names. + */ + Array ListAllNames() const { + Array names; + for (const auto& kv : plugin_map_) { + names.push_back(kv.first); + } + return names; + } + + /*! + * \return a global singleton of the registry. + */ + static PluginRegistry* Global() { + static PluginRegistry* inst = new PluginRegistry(); + return inst; + } + + private: + // map from name to plugins. + std::unordered_map plugin_map_; +}; + +/*! + * \brief List all plugin names. + * \return the corresponding plugin names. + */ +const Array ListPluginNames(); + +/*! + * \brief Get the registered plugin. + * \param name The name of the Plugin. + * \return the corresponding plugin. + */ +const Plugin GetPlugin(const String& name); + +/*! + * \brief Check if an plugin is registered. + * \param name The name of the item. + * \return Whether the plugin is registered. 
+ */ +bool IsPlugin(const String& name); + +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_MSC_CORE_IR_PLUGIN_H_ diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index bf5f36b4605a..f58f95ae53b0 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -78,6 +78,34 @@ int CommonUtils::CompareVersion(const Array& given_version, return CompareVersion(int_given_version, int_target_version); } +const String CommonUtils::ToAttrKey(const String& key) { + if (key == "name") { + return msc_attr::kName; + } + if (key == "optype") { + return msc_attr::kOptype; + } + if (key == "op_attrs") { + return msc_attr::kOpattrs; + } + if (key == "layout") { + return msc_attr::kLayout; + } + if (key == "shared_ref") { + return msc_attr::kSharedRef; + } + if (key == "unique") { + return msc_attr::kUnique; + } + if (key == "input_layouts") { + return msc_attr::kInputLayouts; + } + if (key == "consumer_type") { + return msc_attr::kConsumerType; + } + LOG_FATAL << "Unexpected key " << key; +} + bool StringUtils::Contains(const String& src_string, const String& sub_string) { if (src_string.size() == 0) { return false; @@ -150,6 +178,15 @@ const String StringUtils::Join(const Array& sub_strings, const String& j return join_str; } +const String StringUtils::Join(const std::vector& sub_strings, + const std::string& joint) { + Array new_strings; + for (const auto& s : sub_strings) { + new_strings.push_back(s); + } + return Join(new_strings, joint); +} + const String StringUtils::Replace(const String& src_string, const String& old_str, const String& new_str) { String new_string; @@ -433,6 +470,10 @@ TVM_REGISTER_GLOBAL("msc.core.CompareVersion") return Integer(CommonUtils::CompareVersion(given_version, target_version)); }); +TVM_REGISTER_GLOBAL("msc.core.ToAttrKey").set_body_typed([](const String& key) -> String { + return CommonUtils::ToAttrKey(key); +}); + } // namespace msc } // namespace contrib } // 
namespace tvm diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h index 43eb7bd33d9e..5762c9635206 100644 --- a/src/contrib/msc/core/utils.h +++ b/src/contrib/msc/core/utils.h @@ -28,6 +28,7 @@ #include #include +#include #include #include @@ -39,6 +40,26 @@ using Expr = tvm::RelayExpr; using RelaxCall = tvm::relax::Call; using RelayCall = tvm::relay::Call; +namespace msc_attr { +/*! \brief Mark the name for the expr. */ +constexpr const char* kName = "Name"; +/*! \brief Mark the optype for the expr. */ +constexpr const char* kOptype = "Optype"; +/*! \brief Mark the optype for the expr. */ +constexpr const char* kOpattrs = "Opattrs"; +/*! \brief Mark the layout for the expr. */ +constexpr const char* kLayout = "Layout"; +/*! \brief Mark the share reference for the expr. */ +constexpr const char* kSharedRef = "SharedRef"; + +/*! \brief Mark the unique name for the func. */ +constexpr const char* kUnique = "Unique"; +/*! \brief Mark the input layout for the func. */ +constexpr const char* kInputLayouts = "InputLayouts"; +/*! \brief Mark the consumer type for the func. */ +constexpr const char* kConsumerType = "ConsumerType"; +} // namespace msc_attr + /*! * \brief Utils for Common. */ @@ -64,6 +85,11 @@ class CommonUtils { const std::vector& target_version); TVM_DLL static int CompareVersion(const Array& given_version, const Array& target_version); + /*! + * \brief Get attr key. + * \return The attr key. + */ + TVM_DLL static const String ToAttrKey(const String& key); }; /*! @@ -100,6 +126,8 @@ class StringUtils { * \return The String. */ TVM_DLL static const String Join(const Array& sub_strings, const String& joint); + TVM_DLL static const String Join(const std::vector& sub_strings, + const std::string& joint); /*! * \brief Replace the substring old to new in String. 
diff --git a/src/contrib/msc/plugin/base_codegen.h b/src/contrib/msc/plugin/base_codegen.h new file mode 100644 index 000000000000..cd5f03ff7716 --- /dev/null +++ b/src/contrib/msc/plugin/base_codegen.h @@ -0,0 +1,674 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/plugin/base_codegen.h + * \brief The codegen for Plugin. + */ +#ifndef TVM_CONTRIB_MSC_PLUGIN_BASE_CODEGEN_H_ +#define TVM_CONTRIB_MSC_PLUGIN_BASE_CODEGEN_H_ + +#include +#include + +#include +#include +#include +#include +#include + +#include "../core/codegen/code_stack.h" +#include "../core/ir/plugin.h" +#include "../core/printer/cpp_printer.h" +#include "../core/printer/python_printer.h" + +namespace tvm { +namespace contrib { +namespace msc { + +using namespace tvm::script::printer; + +/*! + * \brief CodeGen for Plugin + */ +template +class BasePluginCodeGen { + public: + /*! + * \brief The constructor of BasePluginCodeGen + * \param config the options for codegen. 
+ */ + explicit BasePluginCodeGen(const std::string& config = "") { + config_.reset(new ConfigType()); + if (config.size() > 0) { + std::istringstream is(config); + dmlc::JSONReader reader(&is); + reader.Read(config_.get()); + } + } + + virtual ~BasePluginCodeGen() = default; + + /*! \brief Get plugin sources*/ + virtual const Map GetBuildSources(const std::string& print_options = "") { + Map sources; + // plugin sources + for (const auto& name : ListPluginNames()) { + const auto& plugin = GetPlugin(name); + // attr declare + const String& attr_macro = "TVM_CONTRIB_MSC_" + StringUtils::Upper(plugin->name) + "_ATTR_H_"; + this->stack_.line("#ifndef " + attr_macro) + .line("#define " + attr_macro) + .line() + .line("#include \"plugin_utils.h\"") + .line(); + StartNamespace(); + CodeGenAttrDeclare(plugin); + EndNamespace(); + this->stack_.line("#endif // " + attr_macro); + sources.Set(plugin->name + "_attr.h", ToCppSource(print_options)); + // attr define + this->stack_.line("#include \"" + plugin->name + "_attr.h\"").line(); + StartNamespace(); + CodeGenAttrDefine(plugin); + EndNamespace(); + sources.Set(plugin->name + "_attr.cc", ToCppSource(print_options)); + // op decalre + const String& op_macro = "TVM_CONTRIB_MSC_" + StringUtils::Upper(plugin->name) + "_OP_H_"; + this->stack_.line("#ifndef " + op_macro).line("#define " + op_macro).line(); + CodeGenOpHeader(plugin); + StartNamespace(); + CodeGenOpDeclare(plugin); + EndNamespace(); + this->stack_.line("#endif // " + op_macro); + sources.Set(plugin->name + "_op.h", ToCppSource(print_options)); + // op define + this->stack_.line("#include \"" + plugin->name + "_op.h\"").line(); + StartNamespace(); + CodeGenOpDefine(plugin); + EndNamespace(); + sources.Set(plugin->name + "_op.cc", ToCppSource(print_options)); + // op runtime + if (this->config()->with_runtime) { + CodeGenOpHeader(plugin); + StartNamespace(); + CodeGenOpRuntime(plugin); + EndNamespace(); + sources.Set(plugin->name + "_runtime.cc", 
ToCppSource(print_options)); + } + } + // cmakelists + std::set devices; + for (const auto& name : ListPluginNames()) { + const auto& plugin = GetPlugin(name); + for (const auto& pair : plugin->externs) { + if (StringUtils::EndsWith(pair.first, "_compute")) { + devices.insert(StringUtils::Replace(pair.first, "_compute", "")); + } + } + } + CodeGenCmake(devices); + sources.Set("CMakeLists.txt", ToCppSource(print_options)); + return sources; + } + + /*! \brief Get manager sources*/ + virtual const Map GetManagerSources(const std::string& print_options = "") { + Map sources; + CodeGenManagerDepends(); + this->stack_.class_def("PluginManager(object)").class_start(); + CodeGenManagerMethods(); + for (const auto& name : ListPluginNames()) { + CodeGenOpBuilder(GetPlugin(name)); + } + if (this->config()->need_convert) { + Map symbols; + this->stack_.func_def("get_convert_map") + .func_decorator("classmethod") + .func_arg("cls", "object") + .func_start(); + CodeGenConvertDepends(); + for (const auto& name : ListPluginNames()) { + const auto& plugin = GetPlugin(name); + const auto& symbol = CodeGenOpConvert(plugin); + symbols.Set(plugin, symbol); + } + this->stack_.assign("converters", "{}"); + for (const auto& pair : symbols) { + this->stack_.assign(DocUtils::ToIndex("converters", DocUtils::ToStr(pair.second)), + ConverterName(pair.first)); + } + this->stack_.func_end("converters"); + } + this->stack_.class_end(); + sources.Set("manager.py", ToPySource(print_options)); + return sources; + } + + protected: + /*! \brief Header of plugin files*/ + virtual void CodeGenOpHeader(const Plugin& plugin) { + this->stack_.line("#include \"" + plugin->name + "_attr.h\""); + std::set include_headers; + for (const auto& pair : plugin->externs) { + if (pair.second->header.size() > 0 && !include_headers.count(pair.second->header)) { + this->stack_.line("#include \"" + pair.second->header + "\""); + include_headers.insert(pair.second->header); + } + } + this->stack_.line(); + } + + /*! 
\brief Start the namespace*/ + void StartNamespace() { + this->stack_.line("namespace tvm {") + .line("namespace contrib {") + .line("namespace msc {") + .line("namespace plugin {") + .line(); + } + + /*! \brief End the namespace*/ + void EndNamespace() { + this->stack_.line("} // namespace plugin") + .line("} // namespace msc") + .line("} // namespace contrib") + .line("} // namespace tvm"); + } + + /*! \brief Codegen safe call extern*/ + void CodeGenSafeCall(const PluginExtern& extern_func, + const Array& call_args = Array(), const String& ret = "") { + this->stack_.scope_start("try {").func_call(extern_func->name, ret); + for (const auto& arg : call_args) { + this->stack_.call_arg(arg); + } + this->stack_.scope_end() + .scope_start("} catch (const std::exception& exc) {") + .line("std::cerr << \"Failed to run extern " + extern_func->name + + " : \" << exc.what() << std::endl;") + .line("throw std::runtime_error(\"Failed to run extern " + extern_func->name + "\");") + .scope_end() + .line("}"); + } + + /*! \brief Codegen plugin attr declare*/ + virtual void CodeGenAttrDeclare(const Plugin& plugin) { + this->stack_.struct_start(MetaAttrCls(plugin)).comment("define attributes"); + for (const auto& attr : plugin->attrs) { + this->stack_.declare(ToCppType(attr->type), attr->name); + if (attr->default_value.size() > 0) { + this->stack_.declare_arg(attr->default_value); + } + } + this->stack_.line() + .comment("print method") + .func_def("operator<<", "friend std::ostream&") + .func_arg("out", "std::ostream&") + .func_arg("attrs", "const " + MetaAttrCls(plugin) + "&") + .func_start() + .line("out << \"[" + MetaAttrCls(plugin) + "] : \";"); + for (const auto& attr : plugin->attrs) { + this->stack_.line("out << \"| " + attr->name + "(" + attr->type + ")=\" << attrs." + + attr->name + ";"); + } + this->stack_.func_end("out").struct_end(); + } + + /*! \brief Codegen plugin attr define*/ + virtual void CodeGenAttrDefine(const Plugin& plugin) {} + + /*! 
\brief Codegen plugin op declare*/ + virtual void CodeGenOpDeclare(const Plugin& plugin) = 0; + + /*! \brief Codegen plugin op define*/ + virtual void CodeGenOpDefine(const Plugin& plugin) = 0; + + /*! \brief Codegen plugin runtime*/ + virtual void CodeGenOpRuntime(const Plugin& plugin) {} + + /*! \brief Codegen cmake file*/ + virtual void CodeGenCmake(const std::set& devices) { + CodeGenPreCmake(devices); + CodeGenPostCmake(devices); + } + + /*! \brief Codegen cmake start*/ + void CodeGenPreCmake(const std::set& devices, + const Map& extra_flags = Map()) { + const auto& p_name = this->config()->project_name; + stack_.line("cmake_minimum_required(VERSION " + this->config()->cmake_version + " FATAL_ERROR)") + .line("project(" + p_name + ")"); + if (devices.count("cuda")) { + stack_.line("find_package(CUDA)").line("add_definitions(-DPLUGIN_ENABLE_CUDA)"); + } + stack_.line(); + for (const auto& pair : extra_flags) { + if (pair.second.size() == 0) { + stack_.line("add_definitions(-D" + pair.first + ")"); + } else { + stack_.line("add_definitions(-D" + pair.first + "=" + pair.second + ")"); + } + } + for (const auto& pair : this->config()->flags) { + if (pair.second.size() == 0) { + stack_.line("add_definitions(-D" + pair.first + ")"); + } else { + stack_.line("add_definitions(-D" + pair.first + "=" + pair.second + ")"); + } + } + stack_.line(); + } + + /*! 
\brief Codegen cmake end*/ + void CodeGenPostCmake(const std::set& devices, + const Array& extra_includes = Array(), + const Array& extra_libs = Array()) { + const auto& p_name = this->config()->project_name; + stack_.line() + .line("file(GLOB_RECURSE PLUGIN_HEADERS src/*.h)") + .line("file(GLOB_RECURSE PLUGIN_CC_SRCS src/*.cc)"); + if (devices.count("cuda")) { + stack_.line("file(GLOB_RECURSE PLUGIN_CU_SRCS src/*.cu)"); + } + if (devices.count("cuda")) { + stack_.line("cuda_add_library(" + p_name + " SHARED ${PLUGIN_CC_SRCS} ${PLUGIN_CU_SRCS})"); + } else { + stack_.line("add_library(" + p_name + " SHARED ${PLUGIN_CC_SRCS})"); + } + // define includes + String includes = StringUtils::Join(extra_includes, " "); + if (this->config()->includes.size() > 0) { + includes = includes + " " + StringUtils::Join(this->config()->includes, " "); + } + if (includes.size() > 0) { + stack_.line("target_include_directories(" + p_name + " PUBLIC " + includes + ")"); + } + // define libs + String link_libs = StringUtils::Join(extra_libs, " "); + const auto& libs = StringUtils::Join(this->config()->libs, " "); + if (libs.size() > 0) { + link_libs = link_libs + " " + libs; + } + if (link_libs.size() > 0) { + stack_.line("target_link_libraries(" + p_name + " " + link_libs + ")"); + } + const auto& install_dir = this->config()->install_dir; + if (install_dir.size() > 0) { + stack_.line() + .line("SET(LIBRARY_OUTPUT_PATH " + install_dir + "/lib)") + .line("file(COPY ${PLUGIN_HEADERS} DESTINATION " + install_dir + "/include)"); + if (this->config()->libs.size() > 0) { + stack_.line("file(COPY " + libs + " DESTINATION " + install_dir + "/lib)"); + } + } + } + + /*! \brief Codegen manager depends*/ + virtual void CodeGenManagerDepends() { + this->stack_.line("import os") + .line("import shutil") + .line("import ctypes") + .line("from typing import Any, List, Dict") + .line(); + } + + /*! 
\brief Codegen manager methods*/ + virtual void CodeGenManagerMethods() { + // init method + stack_.func_def("__init__") + .func_arg("self", "object") + .func_arg("root", "str", "None") + .func_start() + .cond_if("root is None") + .assign("root", "os.path.dirname(__name__)") + .cond_end() + .assign(DocUtils::ToAttrAccess("self", "_lib_folder"), "os.path.join(root, \"lib\")") + .func_call("assert") + .inplace_start("os.path.isdir") + .call_arg(DocUtils::ToAttrAccess("self", "_lib_folder")) + .inplace_end() + .assign(DocUtils::ToAttrAccess("self", "_include_folder"), + "os.path.join(root, \"include\")") + .func_call("assert") + .inplace_start("os.path.isdir") + .call_arg(DocUtils::ToAttrAccess("self", "_include_folder")) + .inplace_end() + .assign(DocUtils::ToAttrAccess("self", "_manager_file"), + "os.path.join(root, \"manager.py\")") + .func_call("assert") + .inplace_start("os.path.isfile") + .call_arg(DocUtils::ToAttrAccess("self", "_manager_file")) + .inplace_end() + .func_call("setup", "", "self") + .func_end(); + // list headers + this->stack_.func_def("list_includes") + .func_arg("self", "object") + .func_arg("as_abs", "bool", "False") + .func_start() + .assign("includes", "[]") + .for_start("f", "os.listdir(self._include_folder)") + .cond_if("as_abs") + .func_call("append", "", "includes") + .inplace_start("os.path.join") + .call_arg(DocUtils::ToAttrAccess("self", "_include_folder")) + .call_arg("f") + .inplace_end() + .cond_else() + .func_call("append", "", "includes") + .call_arg("f") + .cond_end() + .for_end() + .func_end("includes"); + // copy the headers + this->stack_.func_def("copy_includes") + .func_arg("self", "object") + .func_arg("dst", "str") + .func_start() + .cond_if("not os.path.isdir(dst)") + .func_call("makedirs", "", "os") + .call_arg("dst") + .cond_end() + .for_start("header", "os.listdir(self._include_folder)") + .func_call("shutil.copyfile") + .inplace_start("os.path.join") + .call_arg(DocUtils::ToAttrAccess("self", "_include_folder")) + 
.call_arg("header") + .inplace_end() + .inplace_start("os.path.join") + .call_arg("dst") + .call_arg("header") + .inplace_end() + .for_end() + .func_end(); + // list libs + this->stack_.func_def("list_libs") + .func_arg("self", "object") + .func_arg("as_abs", "bool", "False") + .func_start() + .assign("libs", "[]") + .for_start("f", "os.listdir(self._lib_folder)") + .cond_if("as_abs") + .func_call("append", "", "libs") + .inplace_start("os.path.join") + .call_arg(DocUtils::ToAttrAccess("self", "_lib_folder")) + .call_arg("f") + .inplace_end() + .cond_else() + .func_call("append", "", "libs") + .call_arg("f") + .cond_end() + .for_end() + .func_end("libs"); + // copy the libs + this->stack_.func_def("copy_libs") + .func_arg("self", "object") + .func_arg("dst", "str") + .func_start() + .cond_if("not os.path.isdir(dst)") + .func_call("makedirs", "", "os") + .call_arg("dst") + .cond_end() + .for_start("lib", "os.listdir(self._lib_folder)") + .func_call("shutil.copyfile") + .inplace_start("os.path.join") + .call_arg(DocUtils::ToAttrAccess("self", "_lib_folder")) + .call_arg("lib") + .inplace_end() + .inplace_start("os.path.join") + .call_arg("dst") + .call_arg("lib") + .inplace_end() + .for_end() + .func_end(); + // export method + this->stack_.func_def("export") + .func_arg("self", "object") + .func_arg("dst", "str") + .func_start() + .func_call("copy_includes", "", "self") + .inplace_start("os.path.join") + .call_arg("dst") + .call_arg(DocUtils::ToStr("include")) + .inplace_end() + .func_call("copy_libs", "", "self") + .inplace_start("os.path.join") + .call_arg("dst") + .call_arg(DocUtils::ToStr("lib")) + .inplace_end() + .func_call("shutil.copyfile") + .call_arg(DocUtils::ToAttrAccess("self", "_manager_file")) + .inplace_start("os.path.join") + .call_arg("dst") + .call_arg(DocUtils::ToStr("manager.py")) + .inplace_end() + .func_end(); + // get op names + this->stack_.func_def("get_op_names", "List[str]") + .func_arg("self", "object") + .func_start() + .assign("names", 
"[]"); + for (const auto& name : ListPluginNames()) { + this->stack_.func_call("append", "", "names").call_arg(DocUtils::ToStr(name)); + } + this->stack_.func_end("names"); + // get ops info + this->stack_.func_def("get_ops_info", "dict") + .func_arg("self", "object") + .func_start() + .assign("info", "{}"); + for (const auto& name : ListPluginNames()) { + ICHECK(this->config()->ops_info.count(name)) << "Can not find op info for " << name; + const auto& info = this->config()->ops_info[name]; + this->stack_.assign(DocUtils::ToIndex("info", DocUtils::ToStr(name)), info); + } + this->stack_.func_end("info"); + } + + /*! \brief Codegen manager for plugin*/ + virtual void CodeGenOpBuilder(const Plugin& plugin) {} + + /*! \brief Codegen convert depends*/ + virtual void CodeGenConvertDepends() { + this->stack_.line("from tvm import relax") + .line("from tvm.relax import call_dps_packed") + .line("from tvm.contrib.msc.plugin import utils as plugin_utils") + .line("from tvm.contrib.msc.plugin.op import _ffi_api as _plugin_api") + .line("from tvm.contrib.msc.core import utils as msc_utils") + .line(); + } + + /*! \brief Codegen convert function for plugin*/ + virtual const String CodeGenOpConvert(const Plugin& plugin) { return plugin->name; } + + /*! \brief Change code stack to cpp source*/ + const String ToCppSource(const std::string& print_options = "") { + CppPrinter printer(print_options); + for (const auto& d : this->stack_.GetDocs()) { + printer.Append(d); + } + this->stack_.Reset(); + return printer.GetString(); + } + + /*! 
\brief Change code stack to python source*/ + const String ToPySource(const std::string& print_options = "") { + PythonPrinter printer(print_options); + for (const auto& d : this->stack_.GetDocs()) { + printer.Append(d); + } + this->stack_.Reset(); + return printer.GetString(); + } + + std::vector> GetDtypeMatrix(const Plugin& plugin) { + std::vector> matrix; + if (plugin->support_dtypes.size() == 0) { + std::unordered_map dtypes; + for (size_t i = 0; i < plugin->inputs.size(); i++) { + dtypes[i] = plugin->inputs[i]->dtype; + } + matrix.push_back(dtypes); + } else { + Array templates; + Array> condidates; + for (const auto& pair : plugin->support_dtypes) { + templates.push_back(pair.first); + condidates.push_back(pair.second); + } + for (const auto& t_dtypes : ArrayUtils::Product(condidates)) { + std::unordered_map dtypes; + for (size_t i = 0; i < templates.size(); i++) { + for (size_t in_idx = 0; in_idx < plugin->inputs.size(); in_idx++) { + if (plugin->inputs[in_idx]->dtype == templates[i]) { + dtypes[in_idx] = t_dtypes[i]; + } + } + } + for (size_t i = 0; i < plugin->inputs.size(); i++) { + if (dtypes.count(i)) { + continue; + } + dtypes[i] = plugin->inputs[i]->dtype; + } + matrix.push_back(dtypes); + } + } + return matrix; + } + + const Map GetTensorDtypes(const Plugin& plugin, + const std::unordered_map& dtypes) { + Map tensor_dtypes; + for (const auto& pair : dtypes) { + const String& ref_dtype = plugin->inputs[pair.first]->dtype; + for (const auto& t : plugin->inputs) { + if (t->dtype == ref_dtype) { + tensor_dtypes.Set(t->name, pair.second); + } + } + for (const auto& t : plugin->outputs) { + if (t->dtype == ref_dtype) { + tensor_dtypes.Set(t->name, pair.second); + } + } + for (const auto& t : plugin->buffers) { + if (t->dtype == ref_dtype) { + tensor_dtypes.Set(t->name, pair.second); + } + } + } + return tensor_dtypes; + } + + /*! 
\brief Change plugin comment in python*/ + const String GetPyComment(const Plugin& plugin) { + String comment = "Python wrapper for " + plugin->name + "\nInputs\n------"; + for (const auto& t : plugin->inputs) { + comment = comment + "\n" + t->name + ": " + t->dtype + "\n " + t->describe; + } + comment = comment + "\nOutputs\n-------"; + for (const auto& t : plugin->outputs) { + comment = comment + "\n" + t->name + ": " + t->dtype + "\n " + t->describe; + } + if (plugin->attrs.size() > 0) { + comment = comment + "\nAttributes\n-----------"; + for (const auto& a : plugin->attrs) { + comment = comment + "\n" + a->name + ": " + ToPyType(a->type) + "\n " + a->describe; + } + } + return comment; + } + + /*! \brief Get class name for meta attrs*/ + const String MetaAttrCls(const Plugin& plugin) const { return plugin->name + "MetaAttr"; } + + /*! \brief Get converter name for plugin*/ + const String ConverterName(const Plugin& plugin) const { return plugin->name + "Converter"; } + + /*! \brief Check if the type is list type. */ + bool IsListType(const String& type) { return StringUtils::StartsWith(type, "list"); } + + /*! \brief Get type of element. */ + const String GetEleType(const String& type) { + if (!IsListType(type)) { + return ""; + } + return StringUtils::Replace(StringUtils::Replace(type, "list(", ""), ")", ""); + } + + /*! \brief Type name in cpp*/ + virtual const String ToCppType(const String& type) { + if (IsListType(type)) { + const auto& ele_type = GetEleType(type); + return "std::vector<" + ToCppType(ele_type) + ">"; + } + if (type == "int64") { + return "int64_t"; + } + if (type == "int32" || type == "int") { + return "int32_t"; + } + if (type == "int8") { + return "int8_t"; + } + if (type == "string") { + return "std::string"; + } + return type; + } + + /*! 
\brief Type name in python*/ + virtual const String ToPyType(const String& type) { + if (IsListType(type)) { + const auto& ele_type = GetEleType(type); + return "List[" + ToPyType(ele_type) + "]"; + } + if (type == "int64" || type == "int32" || type == "int" || type == "int8") { + return "int"; + } + if (type == "string") { + return "str"; + } + return type; + } + + /*! + * \brief Compare version with version in config + * 0 for same version, 1 for greater version, -1 for less version + */ + int CompareVersion(size_t major, size_t minor, size_t patch) { + return CommonUtils::CompareVersion(this->config()->version, {major, minor, patch}); + } + + /*! \brief The config of plugin codegen*/ + const std::shared_ptr config() { return config_; } + + /*! \brief The stack of codes*/ + CodeStack stack_; + + private: + std::shared_ptr config_; +}; + +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_MSC_PLUGIN_BASE_CODEGEN_H_ diff --git a/src/contrib/msc/plugin/codegen_utils.h b/src/contrib/msc/plugin/codegen_utils.h new file mode 100644 index 000000000000..e61a0944ae4a --- /dev/null +++ b/src/contrib/msc/plugin/codegen_utils.h @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/plugin/codegen_utils.h + * \brief Common utilities for print. + */ +#ifndef TVM_CONTRIB_MSC_PLUGIN_CODEGEN_UTILS_H_ +#define TVM_CONTRIB_MSC_PLUGIN_CODEGEN_UTILS_H_ + +#include +#include +#include + +namespace tvm { +namespace contrib { +namespace msc { + +#define PLUGIN_CODEGEN_CONFIG_MEMBERS \ + bool need_convert{false}; \ + bool with_runtime{false}; \ + std::string project_name{"msc_plugin"}; \ + std::string cmake_version{"3.5"}; \ + std::string install_dir; \ + std::vector version{0, 0, 0}; \ + std::vector includes; \ + std::vector libs; \ + std::unordered_map flags; \ + std::unordered_map ops_info; + +#define PLUGIN_CODEGEN_CONFIG_PARSE \ + if (key == "need_convert") { \ + reader->Read(&need_convert); \ + } else if (key == "with_runtime") { \ + reader->Read(&with_runtime); \ + } else if (key == "cmake_version") { \ + reader->Read(&cmake_version); \ + } else if (key == "project_name") { \ + reader->Read(&project_name); \ + } else if (key == "install_dir") { \ + reader->Read(&install_dir); \ + } else if (key == "version") { \ + reader->Read(&version); \ + } else if (key == "includes") { \ + reader->Read(&includes); \ + } else if (key == "libs") { \ + reader->Read(&libs); \ + } else if (key == "flags") { \ + reader->Read(&flags); \ + } else if (key == "ops_info") { \ + reader->Read(&ops_info); \ + } else { \ + LOG(FATAL) << "Do not support key " << key; \ + } + +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_MSC_PLUGIN_CODEGEN_UTILS_H_ diff --git a/src/contrib/msc/plugin/tensorrt_codegen.cc b/src/contrib/msc/plugin/tensorrt_codegen.cc new file mode 100644 index 000000000000..e54b9eedfea8 --- /dev/null +++ b/src/contrib/msc/plugin/tensorrt_codegen.cc @@ -0,0 +1,901 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/plugin/tensorrt_codegen.cc + */ +#include "tensorrt_codegen.h" + +#include +namespace tvm { +namespace contrib { +namespace msc { + +void TensorRTPluginCodeGen::CodeGenAttrDeclare(const Plugin& plugin) { + BasePluginCodeGen::CodeGenAttrDeclare(plugin); + const auto& attr_name = MetaAttrCls(plugin); + // serialize size for attr + stack_.comment("serialize size").func_def(attr_name + "_serialize_size", "size_t"); + // serialize method for attr + stack_.comment("serialize method") + .func_def(attr_name + "_serialize", "char*") + .func_arg("meta_attr", "const " + attr_name + "&") + .func_arg("buffer", "char*"); + // deserialize method for attr + stack_.comment("deserialize method") + .func_def(attr_name + "_deserialize", "const char*") + .func_arg("meta_attr", attr_name + "&") + .func_arg("buffer", "const char*"); + // attr to field + stack_.comment("meta attr to field") + .func_def(attr_name + "_to_fields") + .func_arg("fields", "std::vector&"); + // attr from field + stack_.comment("meta attr from field") + .func_def(attr_name + "_from_fields", "const " + attr_name) + .func_arg("fields", "const PluginField*"); +} + +void TensorRTPluginCodeGen::CodeGenAttrDefine(const Plugin& plugin) { + const auto& attr_name = MetaAttrCls(plugin); + // 
serialize size for attr + stack_.func_def(attr_name + "_serialize_size", "size_t").func_start().assign("size", 0, "size_t"); + for (const auto& a : plugin->attrs) { + stack_.comment("attr " + a->name + "(" + a->type + ")"); + if (IsListType(a->type)) { + LOG_FATAL << "attribute type " << a->type << " is not supported"; + const auto& ele_type = GetEleType(a->type); + stack_.assign("size", "size + sizeof(size_t)") + .for_start("a", DocUtils::ToAttrAccess("meta_attr", a->name)) + .assign("size", "size + sizeof(" + ToCppType(ele_type) + ")") + .for_end(); + } else { + stack_.assign("size", "size + sizeof(" + ToCppType(a->type) + ")"); + } + } + stack_.func_end("size"); + // serialize method for attr + stack_.func_def(attr_name + "_serialize", "char*") + .func_arg("meta_attr", "const " + attr_name + "&") + .func_arg("buffer", "char*") + .func_start() + .assign("start", "buffer", "const char*"); + for (const auto& a : plugin->attrs) { + stack_.func_call("TRTUtils::ValToBuffer") + .call_arg("buffer") + .call_arg(DocUtils::ToAttrAccess("meta_attr", a->name)); + } + stack_.func_call(attr_name + "_serialize_size", DocUtils::ToDeclare("size_t", "expected")) + .line("assert(buffer == start + expected);") + .func_end("buffer"); + // deserialize method for attr + stack_.func_def(attr_name + "_deserialize", "const char*") + .func_arg("meta_attr", attr_name + "&") + .func_arg("buffer", "const char*") + .func_start() + .assign("start", "buffer", "const char*"); + for (const auto& a : plugin->attrs) { + stack_.func_call("TRTUtils::ValFromBuffer") + .call_arg("buffer") + .call_arg(DocUtils::ToAttrAccess("meta_attr", a->name)); + } + stack_.func_call(attr_name + "_serialize_size", DocUtils::ToDeclare("size_t", "expected")) + .line("assert(buffer == start + expected);") + .func_end("buffer"); + // attr to field + stack_.func_def(attr_name + "_to_fields") + .func_arg("fields", "std::vector&") + .func_start(); + for (const auto& a : plugin->attrs) { + stack_.func_call("emplace_back", "", 
"fields") + .inplace_start("TRTUtils::ToField") + .call_arg(DocUtils::ToStr(a->name)) + .call_arg(DocUtils::ToStr(a->type)) + .inplace_end(); + } + stack_.func_end(); + // attr from field + stack_.func_def(attr_name + "_from_fields", "const " + attr_name) + .func_arg("fields", "const PluginField*") + .func_start() + .declare(attr_name, "meta_attr") + .for_start("i", 0, plugin->attrs.size()); + for (size_t i = 0; i < plugin->attrs.size(); i++) { + const auto& attr = plugin->attrs[i]; + const String& cond = "strcmp(fields[i].name, \"" + attr->name + "\") == 0"; + if (i == 0) { + stack_.switch_start(cond); + } else { + stack_.switch_case(cond); + } + stack_.func_call("TRTUtils::FromField") + .call_arg(DocUtils::ToIndex("fields", "i")) + .call_arg(DocUtils::ToAttrAccess("meta_attr", attr->name)); + } + stack_.switch_end().for_end().func_end("meta_attr"); +} + +void TensorRTPluginCodeGen::CodeGenOpHeader(const Plugin& plugin) { + BasePluginCodeGen::CodeGenOpHeader(plugin); + stack_.line("using namespace nvinfer1;").line(); +} + +void TensorRTPluginCodeGen::CodeGenOpDeclare(const Plugin& plugin) { + if (!IsMixPrecision(plugin)) { + // static plugin op + const auto& op_static = OpCls(plugin, false); + stack_.class_def(op_static + " : public IPluginV2").class_start().scope_start("public:"); + CodegenOpCommonMethods(plugin, false, true); + stack_.comment("special methods for " + op_static) + .func_def("getOutputDimensions", "Dims") + .func_decorator("noexcept override") + .func_arg("index", "int") + .func_arg("in_dims", "const Dims*") + .func_arg("n_inputs", "int") + .func_def("configureWithFormat") + .func_decorator("noexcept override") + .func_arg("in_dims", "const Dims*") + .func_arg("n_inputs", "int") + .func_arg("out_dims", "const Dims*") + .func_arg("n_outputs", "int") + .func_arg("dtype", "DataType") + .func_arg("format", "PluginFormat") + .func_arg("max_batch", "int") + .func_def("supportsFormat", "bool") + .func_decorator("const noexcept override") + 
.func_arg("dtype", "DataType") + .func_arg("format", "PluginFormat") + .func_def("getWorkspaceSize", "size_t") + .func_decorator("const noexcept override") + .func_arg("max_batch", "int") + .func_def("enqueue", "int") + .func_decorator("noexcept override") + .func_arg("batch_size", "int") + .func_arg("inputs", "const void* const*") + .func_arg("outputs", "void* const*") + .func_arg("workspace", "void*") + .func_arg("stream", "cudaStream_t") + .scope_end(); + CodegenOpMembers(plugin, false); + stack_.class_end(); + + // static plugin creator + CodegenCreator(plugin, false, true); + } + // dynamic plugin op + const auto& op_dynamic = OpCls(plugin, true); + stack_.class_def(op_dynamic + " : public IPluginV2DynamicExt") + .class_start() + .scope_start("public:"); + CodegenOpCommonMethods(plugin, true, true); + stack_.comment("special methods for " + op_dynamic) + .func_def("getOutputDataType", "DataType") + .func_decorator("const noexcept override") + .func_arg("index", "int") + .func_arg("in_types", "const DataType*") + .func_arg("n_inputs", "int") + .func_def("getOutputDimensions", "DimsExprs") + .func_decorator("noexcept override") + .func_arg("index", "int") + .func_arg("in_dims", "const DimsExprs*") + .func_arg("n_inputs", "int") + .func_arg("builder", "IExprBuilder&") + .func_def("configurePlugin") + .func_decorator("noexcept override") + .func_arg("in_descs", "const DynamicPluginTensorDesc*") + .func_arg("n_inputs", "int") + .func_arg("out_descs", "const DynamicPluginTensorDesc*") + .func_arg("n_outputs", "int") + .func_def("supportsFormatCombination", "bool") + .func_decorator("noexcept override") + .func_arg("pos", "int") + .func_arg("io_desc", "const PluginTensorDesc*") + .func_arg("n_inputs", "int") + .func_arg("n_outputs", "int") + .func_def("getWorkspaceSize", "size_t") + .func_decorator("const noexcept override") + .func_arg("in_descs", "const PluginTensorDesc*") + .func_arg("n_inputs", "int") + .func_arg("out_descs", "const PluginTensorDesc*") + 
.func_arg("n_outputs", "int") + .func_def("enqueue", "int") + .func_decorator("noexcept override") + .func_arg("input_descs", "const PluginTensorDesc*") + .func_arg("output_descs", "const PluginTensorDesc*") + .func_arg("inputs", "const void* const*") + .func_arg("outputs", "void* const*") + .func_arg("workspace", "void*") + .func_arg("stream", "cudaStream_t") + .scope_end(); + CodegenOpMembers(plugin, true); + stack_.class_end(); + + // dynamic plugin creator + CodegenCreator(plugin, true, true); +} + +void TensorRTPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { + if (!IsMixPrecision(plugin)) { + // static op + const auto& op_static = OpCls(plugin, false); + CodegenOpCommonMethods(plugin, false, false); + // getOutputDimensions + stack_.func_def(op_static + "::getOutputDimensions", "Dims") + .func_decorator("noexcept") + .func_arg("index", "int") + .func_arg("in_dims", "const Dims*") + .func_arg("n_inputs", "int") + .func_start(); + CodegenOutputInfer(plugin, false); + stack_ + .func_call("shape", DocUtils::ToDeclare("MetaShape", "out_shape"), + DocUtils::ToIndex("output_metas_", "index")) + .func_call("TRTUtils::ToDims", DocUtils::ToDeclare("Dims", "out_dims")) + .call_arg("out_shape") + .func_end("out_dims"); + // configureWithFormat + stack_.func_def(op_static + "::configureWithFormat") + .func_decorator("noexcept") + .func_arg("in_dims", "const Dims*") + .func_arg("n_inputs", "int") + .func_arg("out_dims", "const Dims*") + .func_arg("n_outputs", "int") + .func_arg("dtype", "DataType") + .func_arg("format", "PluginFormat") + .func_arg("max_batch", "int") + .func_start() + .assign("dtype_", "dtype") + .line("assert(n_outputs == " + std::to_string(plugin->outputs.size()) + ");"); + CodegenOutputInfer(plugin, false); + stack_.func_end(); + // supportsFormat + stack_.func_def(op_static + "::supportsFormat", "bool") + .func_decorator("const noexcept") + .func_arg("dtype", "DataType") + .func_arg("format", "PluginFormat") + .func_start() + .declare("bool", 
"support"); + size_t cnt = 0; + for (const auto& dtypes : GetDtypeMatrix(plugin)) { + const String& cond = "dtype_ == TRTUtils::ToDataType(\"" + dtypes.at(0) + "\")"; + if (cnt == 0) { + stack_.switch_start(cond); + } else { + stack_.switch_case(cond); + } + stack_.assign("support", true); + cnt++; + } + stack_.switch_case().assign("support", false).switch_end().func_end("support"); + // getWorkspaceSize + stack_.func_def(op_static + "::getWorkspaceSize", "size_t") + .func_decorator("const noexcept") + .func_arg("max_batch", "int") + .func_start() + .assign("size", 0, "size_t"); + if (plugin->externs.count("infer_buffer")) { + CodegenBufferInfer(plugin); + } + stack_.func_end("size"); + // enqueue + stack_.func_def(op_static + "::enqueue", "int") + .func_decorator("noexcept") + .func_arg("batch_size", "int") + .func_arg("inputs", "const void* const*") + .func_arg("outputs", "void* const*") + .func_arg("workspace", "void*") + .func_arg("stream", "cudaStream_t") + .func_start(); + CodegenEnqueue(plugin, false); + stack_.func_end(0); + + // static creator + CodegenCreator(plugin, false, false); + } + // dynamic op + const auto& op_dynamic = OpCls(plugin, true); + CodegenOpCommonMethods(plugin, true, false); + // getOutputDataType + stack_.func_def(op_dynamic + "::getOutputDataType", "DataType") + .func_decorator("const noexcept") + .func_arg("index", "int") + .func_arg("in_types", "const DataType*") + .func_arg("n_inputs", "int") + .func_start() + .declare("DataType", "dtype"); + for (size_t i = 0; i < plugin->outputs.size(); i++) { + if (i == 0) { + stack_.switch_start("index == " + std::to_string(i)); + } else { + stack_.switch_case("index == " + std::to_string(i)); + } + int ref = plugin->FindDtypeRefIdx(plugin->outputs[i]); + if (ref >= 0) { + stack_.assign("dtype", DocUtils::ToIndex("in_types", ref)); + } else { + stack_.func_call("TRTUtils::ToDataType", "dtype") + .call_arg(DocUtils::ToStr(plugin->outputs[i]->dtype)); + } + } + 
stack_.switch_end().func_end("dtype"); + // getOutputDimensions + stack_.func_def(op_dynamic + "::getOutputDimensions", "DimsExprs") + .func_decorator("noexcept") + .func_arg("index", "int") + .func_arg("in_dims", "const DimsExprs*") + .func_arg("n_inputs", "int") + .func_arg("builder", "IExprBuilder&") + .func_start(); + CodegenOutputInfer(plugin, false); + stack_ + .func_call("shape", DocUtils::ToDeclare("MetaShape", "out_shape"), + DocUtils::ToIndex("output_metas_", "index")) + .func_call("TRTUtils::ToDimsExprs", DocUtils::ToDeclare("DimsExprs", "out_dims")) + .call_arg("out_shape") + .call_arg("builder") + .func_end("out_dims"); + // configurePlugin + stack_.func_def(op_dynamic + "::configurePlugin") + .func_decorator("noexcept") + .func_arg("in_descs", "const DynamicPluginTensorDesc*") + .func_arg("n_inputs", "int") + .func_arg("out_descs", "const DynamicPluginTensorDesc*") + .func_arg("n_outputs", "int") + .func_start() + .line("assert(n_outputs == " + std::to_string(plugin->outputs.size()) + ");"); + CodegenOutputInfer(plugin, true); + stack_.func_end(); + // supportsFormatCombination + stack_.func_def(op_dynamic + "::supportsFormatCombination", "bool") + .func_decorator("noexcept") + .func_arg("pos", "int") + .func_arg("io_desc", "const PluginTensorDesc*") + .func_arg("n_inputs", "int") + .func_arg("n_outputs", "int") + .func_start() + .declare("bool", "support"); + size_t cnt = 0; + for (const auto& dtypes : GetDtypeMatrix(plugin)) { + String cond; + for (size_t i = 0; i < plugin->inputs.size(); i++) { + cond = cond + "io_desc[" + std::to_string(i) + "].type == TRTUtils::ToDataType(\"" + + dtypes.at(i) + "\")"; + cond = cond + (i == plugin->inputs.size() - 1 ? 
"" : " && "); + } + if (cnt == 0) { + stack_.switch_start(cond); + } else { + stack_.switch_case(cond); + } + stack_.assign("support", true); + cnt++; + } + stack_.switch_case().assign("support", false).switch_end().func_end("support"); + // getWorkspaceSize + stack_.func_def(op_dynamic + "::getWorkspaceSize", "size_t") + .func_decorator("const noexcept") + .func_arg("in_descs", "const PluginTensorDesc*") + .func_arg("n_inputs", "int") + .func_arg("out_descs", "const PluginTensorDesc*") + .func_arg("n_outputs", "int") + .func_start() + .assign("size", 0, "size_t"); + if (plugin->externs.count("infer_buffer")) { + CodegenBufferInfer(plugin); + } + stack_.func_end("size"); + // enqueue + stack_.func_def(op_dynamic + "::enqueue", "int") + .func_decorator("noexcept") + .func_arg("input_descs", "const PluginTensorDesc*") + .func_arg("output_descs", "const PluginTensorDesc*") + .func_arg("inputs", "const void* const*") + .func_arg("outputs", "void* const*") + .func_arg("workspace", "void*") + .func_arg("stream", "cudaStream_t") + .func_start(); + CodegenEnqueue(plugin, true); + stack_.func_end(0); + + // dynamic creator + CodegenCreator(plugin, true, false); +} + +void TensorRTPluginCodeGen::CodeGenCmake(const std::set& devices) { + Map flags; + flags.Set("PLUGIN_SUPPORT_TENSORRT", ""); + flags.Set("TRT_MAJOR", std::to_string(config()->version[0])); + flags.Set("TRT_MINOR", std::to_string(config()->version[1])); + flags.Set("TRT_PATCH", std::to_string(config()->version[2])); + CodeGenPreCmake(devices, flags); + stack_ + .line("find_path(TRT_INCLUDE_DIR NvInfer.h HINTS " + config()->tensorrt_root + + " PATH_SUFFIXES include)") + .line("find_library(TRT_LIBS nvinfer HINTS " + config()->tensorrt_root + + " PATH_SUFFIXES lib)") + .line("set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -Wno-terminate\")"); + Array includes, libs; + includes.push_back("${TRT_INCLUDE_DIR}"); + libs.push_back("${TRT_LIBS}"); + CodeGenPostCmake(devices, includes, libs); +} + +void 
TensorRTPluginCodeGen::CodeGenManagerMethods() { + BasePluginCodeGen::CodeGenManagerMethods(); + stack_.func_def("setup") + .func_arg("self", "object") + .func_start() + .for_start("lib", "os.listdir(self._lib_folder)") + .assign("lib_file", "os.path.join(self._lib_folder, lib)") + .func_call("CDLL", "", "ctypes") + .call_arg("lib_file") + .for_end() + .func_end(); +} + +void TensorRTPluginCodeGen::CodegenOpCommonMethods(const Plugin& plugin, bool dynamic, + bool in_declare) { + const auto& op_cls = OpCls(plugin, dynamic); + const String& plugin_cls = dynamic ? "IPluginV2DynamicExt" : "IPluginV2"; + if (in_declare) { + stack_.comment("common methods for " + op_cls); + stack_.constructor_def(op_cls).constructor_arg("name", "const std::string&"); + for (const auto& a : plugin->attrs) { + stack_.constructor_arg(a->name, "const " + ToCppType(a->type) + "&"); + } + stack_.constructor_arg("layouts", "const std::vector&") + .constructor_def(op_cls) + .constructor_arg("name", "const std::string&") + .constructor_arg("buffer", "const void*") + .constructor_arg("length", "size_t") + .assign(op_cls + "()", "delete") + .line() + .constructor_def("~" + op_cls) + .func_def("getSerializationSize", "size_t") + .func_decorator("const noexcept override") + .func_def("serialize") + .func_decorator("const noexcept override") + .func_arg("buffer", "void*") + .func_def("getPluginType", "const char*") + .func_decorator("const noexcept override") + .func_def("getPluginVersion", "const char*") + .func_decorator("const noexcept override") + .func_def("getPluginNamespace", "const char*") + .func_decorator("const noexcept override") + .func_def("getNbOutputs", "int") + .func_decorator("const noexcept override") + .func_def("setPluginNamespace") + .func_decorator("noexcept override") + .func_arg("name_space", "const char*") + .func_def("initialize", "int") + .func_decorator("noexcept override") + .func_def("terminate") + .func_decorator("noexcept override") + .func_def("destroy") + 
.func_decorator("noexcept override") + .func_def("clone", plugin_cls + "*") + .func_decorator("const noexcept override"); + } else { + const auto& attr_name = MetaAttrCls(plugin); + // constructor from attrs + stack_.constructor_def(op_cls + "::" + op_cls).constructor_arg("name", "const std::string&"); + for (const auto& a : plugin->attrs) { + stack_.constructor_arg(a->name, "const " + ToCppType(a->type) + "&"); + } + stack_.constructor_arg("layouts", "const std::vector&") + .constructor_start() + .assign("name_", "name"); + for (const auto& a : plugin->attrs) { + stack_.assign(DocUtils::ToAttrAccess("meta_attr_", a->name), a->name); + } + stack_.line("assert(layouts.size() == " + std::to_string(plugin->inputs.size()) + ");") + .assign("layouts_", "layouts"); + stack_.constructor_end(); + // constructor from data + stack_.constructor_def(op_cls + "::" + op_cls) + .constructor_arg("name", "const std::string&") + .constructor_arg("buffer", "const void*") + .constructor_arg("length", "size_t") + .constructor_start() + .assign("name_", "name") + .func_call("static_cast", DocUtils::ToDeclare("const char*", "char_buf")) + .call_arg("buffer") + .assign("start_buf", "char_buf", "const char*") + .func_call(attr_name + "_deserialize", "char_buf") + .call_arg("meta_attr_") + .call_arg("char_buf") + .func_call("TRTUtils::ValFromBuffer") + .call_arg("char_buf") + .call_arg("dtype_") + .func_call("TRTUtils::ValFromBuffer") + .call_arg("char_buf") + .call_arg("layouts_") + .line("assert(layouts_.size() == " + std::to_string(plugin->inputs.size()) + ");") + .line("assert(char_buf == (start_buf + length));") + .constructor_end(); + // deconstructor + stack_.constructor_def(op_cls + "::~" + op_cls) + .constructor_start() + .comment("ignore deconstruct of " + op_cls) + .constructor_end(); + // getSerializationSize + stack_.func_def(op_cls + "::getSerializationSize", "size_t") + .func_decorator("const noexcept") + .func_start() + .assign("size", attr_name + "_serialize_size()", 
"size_t") + .assign("size", "size + sizeof(dtype_)") + .assign("size", "size + sizeof(size_t)") + .for_start("layout", "layouts_") + .assign("size", "size + sizeof(size_t) + layout.size() * sizeof(char)") + .for_end() + .func_end("size"); + // serialize + stack_.func_def(op_cls + "::serialize") + .func_decorator("const noexcept") + .func_arg("buffer", "void*") + .func_start() + .func_call("static_cast", DocUtils::ToDeclare("char*", "char_buf")) + .call_arg("buffer") + .assign("start_buf", "char_buf", "const char*") + .func_call(attr_name + "_serialize", "char_buf") + .call_arg("meta_attr_") + .call_arg("char_buf") + .func_call("TRTUtils::ValToBuffer") + .call_arg("char_buf") + .call_arg("dtype_") + .func_call("TRTUtils::ValToBuffer") + .call_arg("char_buf") + .call_arg("layouts_") + .line("assert(char_buf == (start_buf + getSerializationSize()));") + .func_end(); + // getPluginType + const String& plugin_type = plugin->name + (dynamic ? "_dynamic" : ""); + stack_.func_def(op_cls + "::getPluginType", "const char*") + .func_decorator("const noexcept") + .func_start() + .func_end(DocUtils::ToStr(plugin_type)); + // getPluginVersion + stack_.func_def(op_cls + "::getPluginVersion", "const char*") + .func_decorator("const noexcept") + .func_start() + .func_end(DocUtils::ToStr("1")); + // getPluginNamespace + stack_.func_def(op_cls + "::getPluginNamespace", "const char*") + .func_decorator("const noexcept") + .func_start() + .func_call("c_str", DocUtils::ToDeclare("const char*", "name"), + DocUtils::ToDoc("name_space_")) + .func_end("name"); + // getNbOutputs + stack_.func_def(op_cls + "::getNbOutputs", "int") + .func_decorator("const noexcept") + .func_start() + .func_end(plugin->outputs.size()); + // setPluginNamespace + stack_.func_def(op_cls + "::setPluginNamespace") + .func_decorator("noexcept") + .func_arg("name_space", "const char*") + .func_start() + .assign("name_space_", "name_space") + .func_end(); + // initialize + stack_.func_def(op_cls + "::initialize", 
"int") + .func_decorator("noexcept") + .func_start() + .func_end(0); + // terminate + stack_.func_def(op_cls + "::terminate") + .func_decorator("noexcept") + .func_start() + .comment("Ignore teminate for " + plugin->name) + .func_end(); + // destroy + stack_.func_def(op_cls + "::destroy") + .func_decorator("noexcept") + .func_start() + .line("delete this;") + .func_end(); + // clone + stack_.func_def(op_cls + "::clone", plugin_cls + "*") + .func_decorator("const noexcept") + .func_start() + .func_call("new " + op_cls, DocUtils::ToDeclare(plugin_cls + "*", "plugin")) + .call_arg("name_"); + for (const auto& a : plugin->attrs) { + stack_.call_arg(DocUtils::ToAttrAccess("meta_attr_", a->name)); + } + stack_.call_arg("layouts_").func_end("plugin"); + } +} + +void TensorRTPluginCodeGen::CodegenOpMembers(const Plugin& plugin, bool dynamic) { + stack_.scope_start("private:") + .declare("std::string", "name_") + .declare("std::string", "name_space_") + .declare("DataType", "dtype_", 0, false) + .declare_arg("DataType::kFLOAT") + .declare(MetaAttrCls(plugin), "meta_attr_") + .declare("std::vector", "layouts_") + .declare("std::vector", "input_metas_") + .declare("std::vector", "output_metas_"); + if (plugin->externs.count("infer_buffer")) { + stack_.declare("std::vector", "buffer_metas_"); + } + stack_.scope_end().line(); +} + +void TensorRTPluginCodeGen::CodegenCreator(const Plugin& plugin, bool dynamic, bool in_declare) { + const auto& creator_cls = CreatorCls(plugin, dynamic); + const String& plugin_cls = dynamic ? 
"IPluginV2DynamicExt" : "IPluginV2"; + if (in_declare) { + stack_.class_def(creator_cls + " : public IPluginCreator") + .class_start() + .scope_start("public:") + .constructor_def(creator_cls) + .func_def("getPluginName", "const char*") + .func_decorator("const noexcept override") + .func_def("getPluginVersion", "const char*") + .func_decorator("const noexcept override") + .func_def("getPluginNamespace", "const char*") + .func_decorator("const noexcept override") + .func_def("getFieldNames", "const PluginFieldCollection*") + .func_decorator("noexcept override") + .func_def("setPluginNamespace") + .func_decorator("noexcept override") + .func_arg("name_space", "const char*") + .func_def("createPlugin", plugin_cls + "*") + .func_decorator("noexcept override") + .func_arg("name", "const char*") + .func_arg("collection", "const PluginFieldCollection*") + .func_def("deserializePlugin", plugin_cls + "*") + .func_decorator("noexcept override") + .func_arg("name", "const char*") + .func_arg("data", "const void*") + .func_arg("length", "size_t") + .scope_end() + .scope_start("private:") + .declare("static PluginFieldCollection", "collection_") + .declare("static std::vector", "fields_") + .declare("std::string", "name_space_") + .scope_end() + .line() + .class_end(); + } else { + const String& attr_name = MetaAttrCls(plugin); + // static members + stack_.comment("static members and register for " + plugin->name) + .declare("PluginFieldCollection", creator_cls + "::collection_") + .declare("std::vector", creator_cls + "::fields_") + .func_call("REGISTER_TENSORRT_PLUGIN") + .call_arg(creator_cls) + .line(); + // constructor + stack_.constructor_def(creator_cls + "::" + creator_cls) + .constructor_start() + .func_call(attr_name + "_to_fields") + .call_arg("fields_"); + for (const auto& t : plugin->inputs) { + stack_.func_call("emplace_back", "", "fields_") + .inplace_start("TRTUtils::ToField") + .call_arg(DocUtils::ToStr("layout_" + t->name)) + 
.call_arg(DocUtils::ToStr("string")) + .inplace_end(); + } + const auto& nb_fields_doc = DocUtils::ToAttrAccess("collection_", "nbFields"); + const auto& fields_doc = DocUtils::ToAttrAccess("collection_", "fields"); + stack_.func_call("size", nb_fields_doc, DocUtils::ToDoc("fields_")) + .func_call("data", fields_doc, DocUtils::ToDoc("fields_")) + .constructor_end(); + // getPluginName + const String& plugin_type = plugin->name + (dynamic ? "_dynamic" : ""); + stack_.func_def(creator_cls + "::getPluginName", "const char*") + .func_decorator("const noexcept") + .func_start() + .func_end(DocUtils::ToStr(plugin_type)); + // getPluginVersion + stack_.func_def(creator_cls + "::getPluginVersion", "const char*") + .func_decorator("const noexcept") + .func_start() + .func_end(DocUtils::ToStr("1")); + // getPluginNamespace + stack_.func_def(creator_cls + "::getPluginNamespace", "const char*") + .func_decorator("const noexcept") + .func_start() + .func_call("c_str", DocUtils::ToDeclare("const char*", "name"), + DocUtils::ToDoc("name_space_")) + .func_end("name"); + // getFieldNames + stack_.func_def(creator_cls + "::getFieldNames", "const PluginFieldCollection*") + .func_decorator("noexcept") + .func_start() + .func_end("&collection_"); + // setPluginNamespace + stack_.func_def(creator_cls + "::setPluginNamespace") + .func_decorator("noexcept") + .func_arg("name_space", "const char*") + .func_start() + .assign("name_space_", "name_space") + .func_end(); + // createPlugin + size_t fields_size = plugin->attrs.size() + plugin->inputs.size(); + const auto& op_cls = OpCls(plugin, dynamic); + stack_.func_def(creator_cls + "::createPlugin", plugin_cls + "*") + .func_decorator("noexcept") + .func_arg("name", "const char*") + .func_arg("collection", "const PluginFieldCollection*") + .func_start() + .line("assert(collection->nbFields == " + std::to_string(fields_size) + ");") + .assign("fields", DocUtils::ToAttrAccess(DocUtils::ToPtr("collection"), "fields"), + "const PluginField*") + 
.func_call(attr_name + "_from_fields", DocUtils::ToDeclare("const auto&", "meta_attr")) + .call_arg("fields") + .declare("std::vector", "layouts") + .func_call("resize", "", "layouts") + .call_arg(plugin->inputs.size()) + .for_start("i", plugin->attrs.size(), fields_size); + for (size_t i = 0; i < plugin->inputs.size(); i++) { + const auto& tensor = plugin->inputs[i]; + const String& cond = "strcmp(fields[i].name, \"layout_" + tensor->name + "\") == 0"; + if (i == 0) { + stack_.switch_start(cond); + } else { + stack_.switch_case(cond); + } + stack_.func_call("TRTUtils::FromField") + .call_arg(DocUtils::ToIndex("fields", "i")) + .call_arg(DocUtils::ToIndex("layouts", i)); + } + stack_.switch_end() + .for_end() + .func_call("new " + op_cls, DocUtils::ToDeclare(op_cls + "*", "plugin")) + .call_arg("name"); + for (const auto& a : plugin->attrs) { + stack_.call_arg(DocUtils::ToAttrAccess("meta_attr", a->name)); + } + stack_.call_arg("layouts") + .func_call("setPluginNamespace", NullOpt, DocUtils::ToPtr("plugin")) + .inplace_start("c_str", NullOpt, DocUtils::ToDoc("name_space_")) + .inplace_end() + .func_end("plugin"); + // deserializePlugin + stack_.func_def(creator_cls + "::deserializePlugin", plugin_cls + "*") + .func_decorator("noexcept") + .func_arg("name", "const char*") + .func_arg("data", "const void*") + .func_arg("length", "size_t") + .func_start() + .func_call("new " + op_cls, DocUtils::ToDeclare(op_cls + "*", "plugin")) + .call_arg("name") + .call_arg("data") + .call_arg("length") + .func_call("setPluginNamespace", NullOpt, DocUtils::ToPtr("plugin")) + .inplace_start("c_str", NullOpt, DocUtils::ToDoc("name_space_")) + .inplace_end() + .func_end("plugin"); + } +} + +void TensorRTPluginCodeGen::CodegenOutputInfer(const Plugin& plugin, bool as_desc) { + Array infer_args{"input_metas_", "meta_attr_", "false"}; + stack_.line("assert(n_inputs == " + std::to_string(plugin->inputs.size()) + ");") + .func_call("resize", "", "input_metas_") + 
.call_arg(plugin->inputs.size()) + .for_start("i", 0, plugin->inputs.size()) + .func_call("TRTUtils::ToMetaTensor", DocUtils::ToIndex("input_metas_", "i")); + if (as_desc) { + stack_.call_arg(DocUtils::ToIndex("in_descs", "i")); + } else { + stack_.call_arg(DocUtils::ToIndex("in_dims", "i")).call_arg("dtype_"); + } + stack_.call_arg(DocUtils::ToIndex("layouts_", "i")).for_end(); + CodeGenSafeCall(plugin->externs["infer_output"], infer_args, "output_metas_"); +} + +void TensorRTPluginCodeGen::CodegenBufferInfer(const Plugin& plugin) { + Array infer_args{"input_metas_", "meta_attr_", "false"}; + CodeGenSafeCall(plugin->externs["infer_buffer"], infer_args, "buffer_metas_"); + stack_.for_start("b", "buffer_metas_") + .assign("size", "size + max_batch * b.size(false)") + .for_end(); +} + +void TensorRTPluginCodeGen::CodegenEnqueue(const Plugin& plugin, bool dynamic) { + ICHECK(plugin->externs.count("cuda_compute")) << "cuda_compute is needed fo TensorRT plugin"; + auto prepare_tensor = [this, &dynamic](const PluginTensor& tensor, + const Map& dtypes, size_t idx, + const String& collect) { + const String& t_name = "d_" + tensor->name; + const String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; + const String& tensor_type = "DataTensor<" + t_dtype + ">"; + const String& anno = collect == "input" ? 
"const " + tensor_type + "&" : tensor_type; + stack_.func_call("TRTUtils::To" + tensor_type, DocUtils::ToDeclare(anno, t_name)); + const auto& t_meta = DocUtils::ToIndex(collect + "_metas_", idx); + if (dynamic) { + stack_.call_arg(t_meta).call_arg(DocUtils::ToIndex(collect + "_descs", idx)); + } else { + stack_.call_arg(t_meta).call_arg("batch_size"); + } + if (collect == "input") { + stack_.call_arg(DocUtils::ToIndex("inputs", idx)); + } else if (collect == "output") { + stack_.call_arg(DocUtils::ToIndex("outputs", idx)); + } else { + stack_.call_arg("workspace + offset"); + } + return t_name; + }; + for (const auto& dtypes : GetDtypeMatrix(plugin)) { + const auto& tensor_dtypes = GetTensorDtypes(plugin, dtypes); + Array compute_args; + String dtype_cond = ""; + if (dynamic) { + for (size_t i = 0; i < plugin->inputs.size(); i++) { + dtype_cond = dtype_cond + "input_descs[" + std::to_string(i) + + "].type == TRTUtils::ToDataType(\"" + dtypes.at(i) + "\")"; + dtype_cond = dtype_cond + (i == plugin->inputs.size() - 1 ? 
"" : " && "); + } + } else { + dtype_cond = "dtype_ == TRTUtils::ToDataType(\"" + dtypes.at(0) + "\")"; + } + // prepare compute datas + stack_.cond_if(dtype_cond).comment("prepare compute datas"); + for (size_t i = 0; i < plugin->inputs.size(); i++) { + const String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); + compute_args.push_back(t_name); + } + for (size_t i = 0; i < plugin->outputs.size(); i++) { + const String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); + compute_args.push_back(t_name); + } + if (plugin->buffers.size() > 0) { + stack_.assign("offset", 0, "size_t"); + for (size_t i = 0; i < plugin->buffers.size(); i++) { + const String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "buffer"); + compute_args.push_back(t_name); + const String& size_name = "size_" + plugin->buffers[i]->name; + stack_ + .func_call("size", DocUtils::ToDeclare("size_t", size_name), + DocUtils::ToIndex("buffer_metas_", i)) + .call_arg(false) + .assign("offset", "offset + batch_size * " + size_name); + } + } + compute_args.push_back("meta_attr_"); + compute_args.push_back("stream"); + CodeGenSafeCall(plugin->externs["cuda_compute"], compute_args); + stack_.cond_end(); + } +} + +TVM_REGISTER_GLOBAL("msc.plugin.GetTensorRTPluginSources") + .set_body_typed([](const String& codegen_config, const String& print_config, + const String& codegen_type) -> Map { + TensorRTPluginCodeGen codegen = TensorRTPluginCodeGen(codegen_config); + if (codegen_type == "build") { + return codegen.GetBuildSources(print_config); + } + if (codegen_type == "manager") { + return codegen.GetManagerSources(print_config); + } + return Map(); + }); + +} // namespace msc +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/msc/plugin/tensorrt_codegen.h b/src/contrib/msc/plugin/tensorrt_codegen.h new file mode 100644 index 000000000000..24fb4e5dfca2 --- /dev/null +++ b/src/contrib/msc/plugin/tensorrt_codegen.h @@ -0,0 +1,134 @@ 
+/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/plugin/tensorrt_codegen.h + * \brief Codegen for tensorrt plugin. + */ +#ifndef TVM_CONTRIB_MSC_PLUGIN_TENSORRT_CODEGEN_H_ +#define TVM_CONTRIB_MSC_PLUGIN_TENSORRT_CODEGEN_H_ + +#include +#include + +#include "base_codegen.h" +#include "codegen_utils.h" + +namespace tvm { +namespace contrib { +namespace msc { + +/*! + * \brief CodeGen config for tensorrt plugin + */ +struct TensorRTPluginCodeGenConfig { + std::string tensorrt_root{"/usr/local/cuda"}; + PLUGIN_CODEGEN_CONFIG_MEMBERS + void Load(dmlc::JSONReader* reader) { + std::string key; + reader->BeginObject(); + while (reader->NextObjectItem(&key)) { + if (key == "tensorrt_root") { + reader->Read(&tensorrt_root); + } else { + PLUGIN_CODEGEN_CONFIG_PARSE + } + } + } +}; + +class TensorRTPluginCodeGen : public BasePluginCodeGen { + public: + /*! + * \brief The constructor of TensorRTPluginCodeGen + * \param config the options for codegen. + */ + explicit TensorRTPluginCodeGen(const std::string& config = "") + : BasePluginCodeGen(config) {} + + protected: + /*! \brief Codegen plugin attr declare*/ + void CodeGenAttrDeclare(const Plugin& plugin) final; + + /*! 
\brief Codegen plugin attr define*/ + void CodeGenAttrDefine(const Plugin& plugin) final; + + /*! \brief Header of plugin files*/ + void CodeGenOpHeader(const Plugin& plugin) final; + + /*! \brief Codegen plugin op declare*/ + void CodeGenOpDeclare(const Plugin& plugin) final; + + /*! \brief Codegen plugin op define*/ + void CodeGenOpDefine(const Plugin& plugin) final; + + /*! \brief Codegen cmake file*/ + void CodeGenCmake(const std::set& devices) final; + + /*! \brief Codegen manager methods*/ + void CodeGenManagerMethods() final; + + private: + /*! \brief Op class name of plugin*/ + const String OpCls(const Plugin& plugin, bool dynamic) const { + return plugin->name + (dynamic ? "DynamicPlugin" : "Plugin"); + } + + /*! \brief Creator class name of plugin*/ + const String CreatorCls(const Plugin& plugin, bool dynamic) const { + return plugin->name + (dynamic ? "DynamicCreator" : "Creator"); + } + + bool IsMixPrecision(const Plugin& plugin) { + for (const auto& dtypes : GetDtypeMatrix(plugin)) { + String ref_dtype = ""; + for (const auto& pair : dtypes) { + if (ref_dtype.size() == 0) { + ref_dtype = pair.second; + } else if (ref_dtype != pair.second) { + return true; + } + } + } + return false; + } + + /*! \brief codegen plugin op common methods declare*/ + void CodegenOpCommonMethods(const Plugin& plugin, bool dynamic, bool in_declare); + + /*! \brief codegen plugin op members define*/ + void CodegenOpMembers(const Plugin& plugin, bool dynamic); + + /*! \brief codegen plugin creator*/ + void CodegenCreator(const Plugin& plugin, bool dynamic, bool in_declare); + + /*! \brief codegen infer output func*/ + void CodegenOutputInfer(const Plugin& plugin, bool as_desc = false); + + /*! \brief codegen infer buffer func*/ + void CodegenBufferInfer(const Plugin& plugin); + + /*! 
\brief codegen enqueue func*/ + void CodegenEnqueue(const Plugin& plugin, bool dynamic); +}; + +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_MSC_PLUGIN_TENSORRT_CODEGEN_H_ diff --git a/src/contrib/msc/plugin/torch_codegen.cc b/src/contrib/msc/plugin/torch_codegen.cc new file mode 100644 index 000000000000..4b8c24f17bbb --- /dev/null +++ b/src/contrib/msc/plugin/torch_codegen.cc @@ -0,0 +1,510 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/contrib/msc/plugin/torch_codegen.cc + */ +#include "torch_codegen.h" + +namespace tvm { +namespace contrib { +namespace msc { + +void TorchPluginCodeGen::CodeGenAttrDeclare(const Plugin& plugin) { + BasePluginCodeGen::CodeGenAttrDeclare(plugin); + const auto& attr_name = MetaAttrCls(plugin); + // serialize method for attr + stack_.comment("serialize method") + .func_def(attr_name + "_serialize", "std::vector") + .func_arg("meta_attr", "const " + attr_name + "&"); + // deserialize method for attr + stack_.comment("deserialize method") + .func_def(attr_name + "_deserialize") + .func_arg("attrs", "const std::vector&") + .func_arg("meta_attr", attr_name + "&"); +} + +void TorchPluginCodeGen::CodeGenAttrDefine(const Plugin& plugin) { + const auto& attr_name = MetaAttrCls(plugin); + // serialize method for attr + stack_.func_def(attr_name + "_serialize", "std::vector") + .func_arg("meta_attr", "const " + attr_name + "&") + .func_start() + .declare("std::vector", "attrs"); + for (const auto& a : plugin->attrs) { + stack_.func_call("push_back", "", "attrs") + .inplace_start("SerializeUtils::ToString") + .call_arg(DocUtils::ToAttrAccess("meta_attr", a->name)) + .inplace_end(); + } + stack_.func_end("attrs"); + // deserialize method for attr + stack_.func_def(attr_name + "_deserialize") + .func_arg("attrs", "const std::vector&") + .func_arg("meta_attr", attr_name + "&") + .func_start(); + for (size_t i = 0; i < plugin->attrs.size(); i++) { + stack_.func_call("SerializeUtils::FromString") + .call_arg(DocUtils::ToIndex("attrs", i)) + .call_arg(DocUtils::ToAttrAccess("meta_attr", plugin->attrs[i]->name)); + } + stack_.func_end(); +} + +void TorchPluginCodeGen::CodeGenOpDeclare(const Plugin& plugin) { + stack_.struct_start(plugin->name + " : torch::CustomClassHolder"); + // constructor + stack_.constructor_def(plugin->name).constructor_arg("attrs", "const std::vector&"); + // serialize method + stack_.comment("serialize method").func_def("serialize", "const 
std::vector"); + // compute method + stack_.comment("main compute") + .func_def("compute", "std::vector") + .func_arg("input_tensors", "const std::vector&"); + // members + stack_.comment("members") + .declare(MetaAttrCls(plugin), "meta_attr_") + .declare("std::vector", "layouts_") + .declare("std::string", "name_"); + stack_.struct_end(); + // entry method + stack_.comment("Entry method for plugin " + plugin->name) + .func_def(EntryName(plugin), "std::vector") + .func_arg("instance", "const c10::intrusive_ptr<" + plugin->name + ">&"); + for (const auto& input : plugin->inputs) { + stack_.func_arg(input->name, "const torch::Tensor&"); + } + for (const auto& a : plugin->attrs) { + stack_.func_arg(a->name, "const " + ToTorchType(a->type) + "&"); + } + stack_.func_arg("name", "const std::string&"); +} + +void TorchPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { + const auto& attr_name = MetaAttrCls(plugin); + // define constructor + stack_.constructor_def(plugin->name + "::" + plugin->name) + .constructor_arg("attrs", "const std::vector&") + .constructor_start() + .comment("get attributes") + .func_call(attr_name + "_deserialize") + .call_arg("attrs") + .call_arg("meta_attr_") + .comment("get extra info") + .assign("name_", DocUtils::ToIndex("attrs", plugin->attrs.size())) + .for_start("i", 1 + plugin->attrs.size(), 1 + plugin->attrs.size() + plugin->inputs.size()) + .func_call("push_back", "", "layouts_") + .inplace_start("MetaLayout") + .call_arg(DocUtils::ToIndex("attrs", "i")) + .inplace_end() + .for_end() + .constructor_end(); + // define serialize + stack_.func_def(plugin->name + "::serialize", "const std::vector") + .func_start() + .assign("attrs", attr_name + "_serialize(meta_attr_)", "std::vector") + .func_call("push_back", "", "attrs") + .call_arg("name_") + .for_start("i", 0, plugin->inputs.size()) + .func_call("push_back", "", "attrs") + .call_arg(DocUtils::ToAttrAccess(DocUtils::ToIndex("layouts_", "i"), "name()")) + .for_end() + 
.func_end("attrs"); + // compute method + stack_.func_def(plugin->name + "::compute", "std::vector") + .func_arg("input_tensors", "const std::vector&") + .func_start() + .declare("std::vector", "output_tensors"); + if (plugin->externs.count("infer_buffer")) { + stack_.declare("std::vector", "buffer_tensors"); + } + stack_.line() + .comment("extract meta inputs") + .declare("std::vector", "input_metas") + .for_start("i", 0, plugin->inputs.size()) + .func_call("push_back", "", "input_metas") + .inplace_start("TorchUtils::ToMetaTensor") + .call_arg(DocUtils::ToIndex("input_tensors", "i")) + .call_arg(DocUtils::ToIndex("layouts_", "i")) + .inplace_end() + .for_end(); + // malloc outputs and buffers + ICHECK(plugin->externs.count("infer_output")) << "Can not find extern shape"; + CodeGenMalloc(plugin, plugin->outputs, "output"); + if (plugin->externs.count("infer_buffer")) { + CodeGenMalloc(plugin, plugin->buffers, "buffer"); + } + // do the compute + String device_cond = ""; + for (size_t i = 0; i < plugin->inputs.size(); i++) { + if (plugin->inputs[i]->device == "cuda" || plugin->inputs[i]->device == "default") { + device_cond = device_cond + "input_tensors[" + std::to_string(i) + "].is_cuda()"; + } else { + device_cond = device_cond + "!input_tensors[" + std::to_string(i) + "].is_cuda()"; + } + device_cond = device_cond + (i == plugin->inputs.size() - 1 ? 
"" : " && "); + } + stack_.line().comment("do the compute").cond_if(device_cond); + CodeGenCompute(plugin, "cuda"); + stack_.cond_else(); + CodeGenCompute(plugin, "cpu"); + stack_.cond_end(); + stack_.func_end("output_tensors"); + + // register op + const auto& entry_name = EntryName(plugin); + stack_.func_def(entry_name, "std::vector") + .func_arg("instance", "const c10::intrusive_ptr<" + plugin->name + ">&"); + for (const auto& input : plugin->inputs) { + stack_.func_arg(input->name, "const torch::Tensor&"); + } + for (const auto& a : plugin->attrs) { + stack_.func_arg(a->name, "const " + ToTorchType(a->type) + "&"); + } + stack_.func_arg("name", "const std::string&"); + stack_.func_start().declare("std::vector", "inputs", 0, false); + for (const auto& input : plugin->inputs) { + stack_.declare_arg(input->name); + } + const auto& outputs_doc = DocUtils::ToDeclare("std::vector", "outputs"); + stack_.func_call("compute", outputs_doc, DocUtils::ToPtr("instance")).call_arg("inputs"); + stack_.func_end("outputs"); + stack_.comment("Bind plugin " + plugin->name + " to python") + .func_def("TORCH_LIBRARY", DocSymbol::Empty()) + .func_arg(plugin->name, DocSymbol::Empty()) + .func_arg("m", DocSymbol::Empty()) + .func_start() + .lambda_def("serialize") + .lambda_arg("op", "const c10::intrusive_ptr<" + plugin->name + ">&") + .lambda_start() + .lambda_end(DocUtils::ToAttrAccess(DocUtils::ToPtr("op"), "serialize()")) + .lambda_def("deserialize") + .lambda_arg("state", "std::vector") + .lambda_start() + .lambda_end("c10::make_intrusive<" + plugin->name + ">(std::move(state))") + .func_call("class_<" + plugin->name + ">", "", "m") + .call_arg(DocUtils::ToStr(plugin->name)) + .method_call("def", true) + .call_arg("torch::init>()") + .method_call("def", true) + .call_arg(DocUtils::ToStr("compute")) + .call_arg("&" + plugin->name + "::compute") + .method_call("def_pickle", true) + .call_arg("serialize") + .call_arg("deserialize") + .func_call("def", "", "m") + 
.call_arg(DocUtils::ToStr(entry_name)) + .call_arg(entry_name) + .func_end(); +} + +void TorchPluginCodeGen::CodeGenCmake(const std::set& devices) { + Map flags; + flags.Set("PLUGIN_SUPPORT_TORCH", ""); + CodeGenPreCmake(devices, flags); + stack_.line() + .line("set(CMAKE_CXX_STANDARD 14)") + .line("list(APPEND CMAKE_PREFIX_PATH \"" + config()->torch_prefix + "\")") + .line("find_package(Torch REQUIRED)"); + Array includes, libs; + libs.push_back("${TORCH_LIBRARIES}"); + CodeGenPostCmake(devices, includes, libs); +} + +void TorchPluginCodeGen::CodeGenManagerDepends() { + BasePluginCodeGen::CodeGenManagerDepends(); + stack_.line("import torch") + .line() + .func_def("to_string", "str") + .func_arg("value", "Any") + .func_start() + .switch_start("isinstance(value, (list, tuple))") + .assign("str_value", "\",\".join([str(len(value))] + [to_string(v) for v in value])") + .switch_case("isinstance(value, bool)") + .assign("str_value", "\"1\" if value else \"0\"") + .switch_case() + .assign("str_value", "str(value)") + .switch_end() + .func_end("str_value"); +} + +void TorchPluginCodeGen::CodeGenManagerMethods() { + BasePluginCodeGen::CodeGenManagerMethods(); + // libs_loaded method + stack_.func_def("libs_loaded") + .func_arg("self", "object") + .func_start() + .assign("loaded_libs", "set()") + .assign("loaded", DocUtils::ToDoc(false)) + .for_start("lib", "torch.classes.loaded_libraries") + .func_call("add", "", "loaded_libs") + .inplace_start("os.path.basename") + .call_arg("lib") + .inplace_end() + .for_end() + .for_start("lib", "os.listdir(self._lib_folder)") + .cond_if("lib in loaded_libs") + .assign("loaded", DocUtils::ToDoc(true)) + .line("break") + .cond_end() + .for_end() + .func_end("loaded"); + // setup method + stack_.func_def("setup") + .func_arg("self", "object") + .func_start() + .for_start("lib", "os.listdir(self._lib_folder)") + .assign("lib_file", "os.path.join(self._lib_folder, lib)") + .cond_if("\"" + config()->project_name + "\" in lib") + 
.func_call("load_library", "", "torch.classes") + .call_arg("lib_file") + .cond_else() + .func_call("CDLL", "", "ctypes") + .call_arg("lib_file") + .cond_end() + .for_end() + .func_end(); +} + +void TorchPluginCodeGen::CodeGenOpBuilder(const Plugin& plugin) { + const auto& entry_name = EntryName(plugin); + stack_.func_def(plugin->name).func_arg("self", "object"); + for (const auto& attr : plugin->attrs) { + stack_.func_arg(attr->name, attr->type, attr->default_value); + } + stack_.func_arg("name", "str", "\"" + plugin->name + "\"") + .func_arg("layouts", "List[str]", "None") + .func_start() + .class_def(plugin->name + "(torch.nn.Module)") + .class_start(); + // init method + stack_.func_def("__init__").func_arg("self", "torch.nn.Module"); + for (const auto& attr : plugin->attrs) { + stack_.func_arg(attr->name, attr->type, attr->default_value); + } + stack_.func_arg("name", "str", "\"" + plugin->name + "\"") + .func_arg("layouts", "List[str]", "None") + .func_start() + .func_call("__init__", "", "super()"); + for (const auto& attr : plugin->attrs) { + stack_.assign(DocUtils::ToAttrAccess("self", attr->name), attr->name); + } + stack_.assign(DocUtils::ToAttrAccess("self", "name"), "name") + .cond_if("layouts is None") + .assign(DocUtils::ToAttrAccess("self", "layouts"), + "[\"\"] * " + std::to_string(plugin->inputs.size())) + .cond_else() + .assign(DocUtils::ToAttrAccess("self", "layouts"), "layouts") + .cond_end() + .line() + .assign("attr_strs", "[]"); + for (const auto& attr : plugin->attrs) { + stack_.func_call("append", "", "attr_strs") + .inplace_start("to_string") + .call_arg(attr->name) + .inplace_end(); + } + stack_.func_call("append", "", "attr_strs") + .call_arg("name") + .func_call("extend", "", "attr_strs") + .call_arg(DocUtils::ToAttrAccess("self", "layouts")) + .line() + .func_call(plugin->name + "." 
+ plugin->name, "self._inner_class", "torch.classes") + .call_arg("attr_strs") + .func_end(); + // forward method + stack_.func_def("forward", "List[torch.Tensor]").func_arg("self", "torch.nn.Module"); + for (const auto& t : plugin->inputs) { + stack_.func_arg(t->name, "torch.Tensor"); + } + stack_.func_start() + .func_call(plugin->name + "." + entry_name, "outputs", "torch.ops") + .call_arg("self._inner_class"); + for (const auto& t : plugin->inputs) { + stack_.call_arg(t->name); + } + for (const auto& a : plugin->attrs) { + stack_.call_arg(DocUtils::ToAttrAccess("self", a->name)); + } + stack_.call_arg(DocUtils::ToAttrAccess("self", "name")); + if (plugin->outputs.size() == 1) { + stack_.func_end(DocUtils::ToIndex("outputs", 0)); + } else { + stack_.func_end("outputs"); + } + // end of inner class + stack_.class_end(); + stack_.func_call(plugin->name, "op"); + for (const auto& attr : plugin->attrs) { + stack_.call_arg(attr->name); + } + stack_.call_arg("name").call_arg("layouts").func_end("op").comment(GetPyComment(plugin), true); +} + +void TorchPluginCodeGen::CodeGenConvertDepends() { + BasePluginCodeGen::CodeGenConvertDepends(); + stack_.line("from torch import fx") + .line("from tvm.relax.frontend.torch.fx_translator import TorchFXImporter") + .line(); +} + +const String TorchPluginCodeGen::CodeGenOpConvert(const Plugin& plugin) { + stack_.func_def(ConverterName(plugin), "relax.Var") + .func_arg("node", "fx.node.Node") + .func_arg("ctx", "TorchFXImporter") + .func_start() + .func_call("retrieve_args", "args", "ctx") + .call_arg("node"); + Array args; + for (size_t i = 0; i < plugin->inputs.size(); i++) { + const auto& tensor = plugin->inputs[i]; + stack_.assign(tensor->name, DocUtils::ToIndex("args", i + 1)); + args.push_back(tensor->name); + } + for (size_t i = 0; i < plugin->attrs.size(); i++) { + const auto& attr = plugin->attrs[i]; + stack_.func_call("plugin_utils.to_expr", attr->name) + .call_arg(DocUtils::ToIndex("args", i + plugin->inputs.size() + 1)); 
+ args.push_back(attr->name); + } + stack_.assign("name", + DocUtils::ToIndex("args", 1 + plugin->inputs.size() + plugin->attrs.size())); + stack_.func_call("relax.Tuple", "args") + .call_arg(DocUtils::ToList(args)) + .func_call("InferStructInfo" + plugin->name, "out_sinfo", "_plugin_api"); + for (const auto& t : plugin->inputs) { + stack_.call_arg(t->name); + } + for (const auto& a : plugin->attrs) { + stack_.call_arg(a->name); + } + stack_.func_call("call_dps_packed", "op") + .call_arg(DocUtils::ToStr(plugin->name)) + .call_arg("args", "args") + .call_arg("list(out_sinfo)", "out_sinfo") + .func_call("msc_utils.set_expr_name", "op") + .call_arg("op") + .call_arg("name") + .func_call("emit", "var", "ctx.block_builder") + .call_arg("op") + .call_arg("name"); + if (plugin->outputs.size() == 1) { + stack_.func_end(DocUtils::ToList(Array{"var"})); + } else { + Array outputs; + for (size_t i = 0; i < plugin->outputs.size(); i++) { + const auto& tensor = plugin->outputs[i]; + stack_.func_call("relax.TupleGetItem", tensor->name).call_arg("var").call_arg(i); + outputs.push_back(tensor->name); + } + stack_.func_end(DocUtils::ToList(outputs)); + } + return EntryName(plugin); +} + +void TorchPluginCodeGen::CodeGenMalloc(const Plugin& plugin, const Array& tensors, + const String& collect) { + Array call_args{"input_metas", "meta_attr_", "true"}; + stack_.line().comment("malloc " + collect).declare("std::vector", collect + "_metas"); + CodeGenSafeCall(plugin->externs["infer_" + collect], call_args, collect + "_metas"); + for (size_t i = 0; i < tensors.size(); i++) { + stack_.func_call("push_back", "", collect + "_tensors") + .inplace_start("TorchUtils::MallocTorchTensor") + .call_arg(DocUtils::ToIndex(collect + "_metas", i)); + int device_idx = plugin->FindDeviceRefIdx(tensors[i]); + if (device_idx >= 0) { + const auto& input_doc = DocUtils::ToIndex("input_tensors", device_idx); + stack_.inplace_start("device", NullOpt, input_doc).inplace_end(); + } else { + 
stack_.inplace_start("TorchUtils::ToTorchDevice") + .call_arg(DocUtils::ToStr(tensors[i]->device)) + .inplace_end(); + } + stack_.inplace_end(); + } +} + +void TorchPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& device) { + auto prepare_tensor = [this](const PluginTensor& tensor, const Map& dtypes, + size_t idx, const String& collect) { + const String& t_name = "d_" + tensor->name; + const String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; + const String& tensor_type = "DataTensor<" + t_dtype + ">"; + const String& anno = collect == "input" ? "const " + tensor_type + "&" : tensor_type; + stack_.func_call("TorchUtils::To" + tensor_type, DocUtils::ToDeclare(anno, t_name)) + .call_arg(DocUtils::ToIndex(collect + "_tensors", idx)) + .call_arg(DocUtils::ToIndex(collect + "_metas", idx)) + .call_arg(collect == "input"); + return t_name; + }; + + if (plugin->externs.count(device + "_compute")) { + for (const auto& dtypes : GetDtypeMatrix(plugin)) { + const auto& tensor_dtypes = GetTensorDtypes(plugin, dtypes); + Array compute_args; + String dtype_cond = ""; + for (size_t i = 0; i < plugin->inputs.size(); i++) { + dtype_cond = dtype_cond + "input_metas[" + std::to_string(i) + + "].data_type() == DataUtils::ToMetaType(\"" + dtypes.at(i) + "\")"; + dtype_cond = dtype_cond + (i == plugin->inputs.size() - 1 ? 
"" : " && "); + } + // prepare compute datas + stack_.cond_if(dtype_cond).comment("prepare compute datas"); + for (size_t i = 0; i < plugin->inputs.size(); i++) { + const String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); + compute_args.push_back(t_name); + } + for (size_t i = 0; i < plugin->outputs.size(); i++) { + const String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); + compute_args.push_back(t_name); + } + for (size_t i = 0; i < plugin->buffers.size(); i++) { + const String& t_name = prepare_tensor(plugin->buffers[i], tensor_dtypes, i, "buffer"); + compute_args.push_back(t_name); + } + compute_args.push_back("meta_attr_"); + if (device == "cuda") { + stack_.func_call("at::cuda::getCurrentCUDAStream", + DocUtils::ToDeclare("cudaStream_t", "stream")); + compute_args.push_back("stream"); + } + CodeGenSafeCall(plugin->externs[device + "_compute"], compute_args); + stack_.cond_end(); + } + } else { + stack_.comment("Skip compute on " + device); + } +} + +TVM_REGISTER_GLOBAL("msc.plugin.GetTorchPluginSources") + .set_body_typed([](const String& codegen_config, const String& print_config, + const String& codegen_type) -> Map { + TorchPluginCodeGen codegen = TorchPluginCodeGen(codegen_config); + if (codegen_type == "build") { + return codegen.GetBuildSources(print_config); + } + if (codegen_type == "manager") { + return codegen.GetManagerSources(print_config); + } + return Map(); + }); + +} // namespace msc +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/msc/plugin/torch_codegen.h b/src/contrib/msc/plugin/torch_codegen.h new file mode 100644 index 000000000000..4452650e2271 --- /dev/null +++ b/src/contrib/msc/plugin/torch_codegen.h @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/plugin/torch_codegen.h + * \brief Codegen for torch plugin. + */ +#ifndef TVM_CONTRIB_MSC_PLUGIN_TORCH_CODEGEN_H_ +#define TVM_CONTRIB_MSC_PLUGIN_TORCH_CODEGEN_H_ + +#include +#include + +#include "base_codegen.h" +#include "codegen_utils.h" + +namespace tvm { +namespace contrib { +namespace msc { + +/*! + * \brief CodeGen config for torch plugin + */ +struct TorchPluginCodeGenConfig { + bool is_training{false}; + std::string torch_prefix{"torch"}; + PLUGIN_CODEGEN_CONFIG_MEMBERS + void Load(dmlc::JSONReader* reader) { + std::string key; + reader->BeginObject(); + while (reader->NextObjectItem(&key)) { + if (key == "is_training") { + reader->Read(&is_training); + } else if (key == "torch_prefix") { + reader->Read(&torch_prefix); + } else { + PLUGIN_CODEGEN_CONFIG_PARSE + } + } + } +}; + +class TorchPluginCodeGen : public BasePluginCodeGen { + public: + /*! + * \brief The constructor of TorchPluginCodeGen + * \param config the options for codegen. + */ + explicit TorchPluginCodeGen(const std::string& config = "") + : BasePluginCodeGen(config) {} + + protected: + /*! \brief Codegen plugin attr declare*/ + void CodeGenAttrDeclare(const Plugin& plugin) final; + + /*! \brief Codegen plugin attr define*/ + void CodeGenAttrDefine(const Plugin& plugin) final; + + /*! 
\brief Codegen plugin op declare*/ + void CodeGenOpDeclare(const Plugin& plugin) final; + + /*! \brief Codegen plugin op define*/ + void CodeGenOpDefine(const Plugin& plugin) final; + + /*! \brief Codegen cmake file*/ + void CodeGenCmake(const std::set& devices) final; + + /*! \brief Codegen manager depends*/ + void CodeGenManagerDepends() final; + + /*! \brief Codegen manager methods*/ + void CodeGenManagerMethods() final; + + /*! \brief Codegen manager member for plugin*/ + void CodeGenOpBuilder(const Plugin& plugin) final; + + /*! \brief Codegen convert depends*/ + void CodeGenConvertDepends() final; + + /*! \brief Codegen convert function for plugin*/ + const String CodeGenOpConvert(const Plugin& plugin) final; + + private: + /*! \brief Codegen malloc for outputs/buffers*/ + void CodeGenMalloc(const Plugin& plugin, const Array& tensors, + const String& collect); + + /*! \brief Codegen compute*/ + void CodeGenCompute(const Plugin& plugin, const String& device); + + /*! \brief Entry name of torch function*/ + const String EntryName(const Plugin& plugin) { + std::string lower_name; + const std::string& name = std::string(plugin->name); + for (size_t i = 0; i < name.size(); i++) { + const char& lower_c = tolower(name[i]); + if (lower_c != name[i] && i > 0) { + lower_name += "_"; + } + lower_name += lower_c; + } + return lower_name + "_entry"; + } + + /*! 
\brief Type name in torch*/ + const String ToTorchType(const String& type) { + if (type == "float") { + return "double"; + } + if (IsListType(type)) { + const auto& ele_type = GetEleType(type); + return "c10::arrayRef<" + ToTorchType(ele_type) + ">"; + } + return BasePluginCodeGen::ToCppType(type); + } +}; + +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_MSC_PLUGIN_TORCH_CODEGEN_H_ diff --git a/src/contrib/msc/plugin/tvm_codegen.cc b/src/contrib/msc/plugin/tvm_codegen.cc new file mode 100644 index 000000000000..08a62c53bf0a --- /dev/null +++ b/src/contrib/msc/plugin/tvm_codegen.cc @@ -0,0 +1,411 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/plugin/tvm_codegen.cc + */ +#include "tvm_codegen.h" + +namespace tvm { +namespace contrib { +namespace msc { + +void TVMPluginCodeGen::CodeGenAttrDeclare(const Plugin& plugin) { + BasePluginCodeGen::CodeGenAttrDeclare(plugin); + const auto& attr_name = MetaAttrCls(plugin); + // exprs to meta_attr + stack_.comment("convert exprs to meta attrs method") + .func_def(attr_name + "_from_exprs", "const " + attr_name); + for (const auto& a : plugin->attrs) { + const String& anno = IsListType(a->type) ? 
"Tuple" : "PrimValue"; + stack_.func_arg(a->name, "const " + anno + "&"); + } + // args to meta_attr + stack_.comment("convert args to meta attrs method") + .func_def(attr_name + "_from_args", "const " + attr_name) + .func_arg("args", "TVMArgs") + .func_arg("pos", "size_t&"); +} + +void TVMPluginCodeGen::CodeGenAttrDefine(const Plugin& plugin) { + const auto& attr_name = MetaAttrCls(plugin); + // exprs to meta_attr + stack_.func_def(attr_name + "_from_exprs", "const " + attr_name); + for (const auto& a : plugin->attrs) { + const String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; + stack_.func_arg(a->name, "const " + anno + "&"); + } + stack_.func_start().declare(attr_name, "meta_attr"); + for (const auto& a : plugin->attrs) { + const String& convert = IsListType(a->type) ? "AttrFromPrims" : "AttrFromPrim"; + stack_.func_call("TVMUtils::" + convert) + .call_arg(a->name) + .call_arg(DocUtils::ToAttrAccess("meta_attr", a->name)); + } + stack_.func_end("meta_attr"); + // args to meta_attr + stack_.comment("convert args to meta attrs method") + .func_def(attr_name + "_from_args", "const " + attr_name) + .func_arg("args", "TVMArgs") + .func_arg("pos", "size_t&") + .func_start() + .declare(attr_name, "meta_attr"); + for (const auto& a : plugin->attrs) { + if (IsListType(a->type)) { + // TODO(meng.tong): support list atribute + LOG_FATAL << "ListType argument is not supported for tvm runtime"; + stack_.func_call("TVMUtils::AttrFromArg", a->name + "_size") + .call_arg(DocUtils::ToIndex("args", "pos")) + .func_call("TVMUtils::AttrFromArgs") + .call_arg("args") + .call_arg("pos") + .call_arg(a->name + "_size") + .call_arg(DocUtils::ToAttrAccess("meta_attr", a->name)) + .assign("pos", "pos + 1 + " + a->name + "_size"); + } else { + stack_.func_call("TVMUtils::AttrFromArg") + .call_arg(DocUtils::ToIndex("args", "pos")) + .call_arg(DocUtils::ToAttrAccess("meta_attr", a->name)) + .assign("pos", "pos + 1"); + } + } + stack_.func_end("meta_attr"); +} + +void 
TVMPluginCodeGen::CodeGenOpDeclare(const Plugin& plugin) { + // infer struct info + stack_.func_def("InferStructInfo" + plugin->name, "Array"); + for (const auto& t : plugin->inputs) { + stack_.func_arg(t->name, "const Expr&"); + } + for (const auto& a : plugin->attrs) { + const String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; + stack_.func_arg(a->name, "const " + anno + "&"); + } + // infer layout + stack_.func_def("InferLayout" + plugin->name, "InferLayoutOutput") + .func_arg("inputs", "const Array&") + .func_arg("var_layout_map", "const VarLayoutMap&"); +} + +void TVMPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { + const auto& attr_name = MetaAttrCls(plugin); + // infer struct info + Array infer_args{"input_metas", "meta_attr", "false"}; + stack_.func_def("InferStructInfo" + plugin->name, "Array"); + for (const auto& t : plugin->inputs) { + stack_.func_arg(t->name, "const Expr&"); + } + for (const auto& a : plugin->attrs) { + const String& anno = IsListType(a->type) ? 
"Tuple" : "PrimValue"; + stack_.func_arg(a->name, "const " + anno + "&"); + } + stack_.func_start() + .comment("extract meta attrs") + .func_call(attr_name + "_from_exprs", DocUtils::ToDeclare("const auto&", "meta_attr")); + for (const auto& a : plugin->attrs) { + stack_.call_arg(a->name); + } + stack_.comment("extract meta inputs").declare("std::vector", "input_metas"); + for (const auto& t : plugin->inputs) { + stack_.func_call("push_back", "", "input_metas") + .inplace_start("TVMUtils::ToMetaTensor") + .call_arg(t->name) + .inplace_end(); + } + stack_.declare("std::vector", "output_metas"); + CodeGenSafeCall(plugin->externs["infer_output"], infer_args, "output_metas"); + stack_.declare("Array", "output_sinfo"); + for (size_t i = 0; i < plugin->outputs.size(); i++) { + stack_.func_call("push_back", "", "output_sinfo") + .inplace_start("TVMUtils::ToTensorStructInfo") + .call_arg(DocUtils::ToIndex("output_metas", i)); + int device_idx = plugin->FindDeviceRefIdx(plugin->outputs[i]); + if (device_idx >= 0) { + stack_.call_arg(plugin->inputs[device_idx]->name); + } else { + stack_.inplace_start("TVMUtils::ToTVMDevice") + .call_arg(plugin->outputs[i]->device) + .inplace_end(); + } + stack_.inplace_end(); + } + stack_.func_end("output_sinfo"); + + // infer layout + stack_.func_def("InferLayout" + plugin->name, "InferLayoutOutput") + .func_arg("inputs", "const Array&") + .func_arg("var_layout_map", "const VarLayoutMap&") + .func_start() + .comment("define attrs"); + for (size_t i = 0; i < plugin->attrs.size(); i++) { + const auto& attr = plugin->attrs[i]; + const String& anno = IsListType(attr->type) ? 
"Tuple" : "PrimValue"; + stack_ + .func_call("Downcast<" + anno + ">", + DocUtils::ToDeclare("const auto&", "attr_" + attr->name)) + .call_arg(DocUtils::ToIndex("inputs", i + plugin->inputs.size())); + } + stack_.declare("Array", "arg_layouts") + .declare("Array", "output_layouts") + .comment("extract meta attrs") + .func_call(attr_name + "_from_exprs", "const " + attr_name + "& meta_attr"); + for (const auto& a : plugin->attrs) { + stack_.call_arg("attr_" + a->name); + } + stack_.comment("extract meta inputs") + .declare("std::vector", "input_metas") + .for_start("i", 0, plugin->inputs.size()) + .func_call("LayoutUtils::InferLayoutDecision", + DocUtils::ToDeclare("const auto&", "in_layout")) + .call_arg(DocUtils::ToIndex("inputs", "i")) + .call_arg("var_layout_map") + .func_call("push_back", "", "arg_layouts") + .call_arg("in_layout") + .func_call("push_back", "", "input_metas") + .inplace_start("TVMUtils::ToMetaTensor") + .call_arg(DocUtils::ToIndex("inputs", "i")) + .call_arg("in_layout") + .inplace_end() + .for_end() + .comment("add fake layout for attrs") + .for_start("i", 0, plugin->attrs.size()) + .func_call("push_back", "", "arg_layouts") + .inplace_start("LayoutDecision") + .call_arg(DocUtils::ToStr("")) + .inplace_end() + .for_end(); + stack_.declare("std::vector", "output_metas"); + CodeGenSafeCall(plugin->externs["infer_output"], infer_args, "output_metas"); + stack_.for_start("i", 0, plugin->outputs.size()) + .func_call("push_back", "", "output_layouts") + .inplace_start("LayoutDecision") + .call_arg(DocUtils::ToAttrAccess(DocUtils::ToIndex("output_metas", "i"), "layout_name()")) + .inplace_end() + .for_end() + .declare("Array", "input_layouts") + .func_call("push_back", "", "input_layouts") + .inplace_start("LayoutDecision") + .call_arg(DocUtils::ToStr("")) + .inplace_end() + .func_call("push_back", "", "input_layouts") + .call_arg("arg_layouts") + .func_call("InferLayoutOutput", DocUtils::ToDeclare("const auto&", "infer_output")) + 
.call_arg("input_layouts") + .call_arg("output_layouts") + .call_arg("Attrs()"); + stack_.func_end("infer_output"); + + // register funcs + stack_.func_call("TVM_REGISTER_GLOBAL") + .call_arg(DocUtils::ToStr("msc.plugin.op.InferStructInfo" + plugin->name)) + .method_call("set_body_typed") + .call_arg("InferStructInfo" + plugin->name) + .line() + .func_call("TVM_REGISTER_GLOBAL") + .call_arg(DocUtils::ToStr("msc.plugin.op.InferLayout" + plugin->name)) + .method_call("set_body_typed") + .call_arg("InferLayout" + plugin->name) + .line(); +} + +void TVMPluginCodeGen::CodeGenOpRuntime(const Plugin& plugin) { + ICHECK(!plugin->externs.count("infer_buffer")) << "infer_buffer is not supported for tvm runtime"; + const auto& attr_name = MetaAttrCls(plugin); + const auto& func_name = ComputeName(plugin); + String device_cond = ""; + for (size_t i = 0; i < plugin->inputs.size(); i++) { + String device_type = ""; + if (plugin->inputs[i]->device == "cuda" || plugin->inputs[i]->device == "default") { + device_type = "DLDeviceType::kDLCUDA"; + } else { + device_type = "DLDeviceType::kDLCPU"; + } + device_cond = device_cond + "TVMUtils::OnDevice(" + plugin->inputs[i]->name + ", " + + device_type + ")" + (i == plugin->inputs.size() - 1 ? 
"" : " && "); + } + stack_.func_def(func_name).func_arg("args", "TVMArgs").func_arg("ret", "TVMRetValue*"); + stack_.func_start().comment("define tensors"); + for (size_t i = 0; i < plugin->inputs.size(); i++) { + stack_.assign(plugin->inputs[i]->name, DocUtils::ToIndex("args", i), "DLTensor*"); + } + stack_.comment("extract meta attrs") + .assign("pos", plugin->inputs.size(), "size_t") + .func_call(attr_name + "_from_args", "const " + attr_name + "& meta_attr") + .call_arg("args") + .call_arg("pos"); + for (size_t i = 0; i < plugin->outputs.size(); i++) { + stack_.assign(plugin->outputs[i]->name, DocUtils::ToIndex("args", "pos + " + std::to_string(i)), + "DLTensor*"); + } + stack_.comment("do the compute").cond_if(device_cond); + CodeGenCompute(plugin, "cuda"); + stack_.cond_else(); + CodeGenCompute(plugin, "cpu"); + stack_.cond_end().func_end(); + // register the compute + stack_.func_call("TVM_REGISTER_GLOBAL") + .call_arg(DocUtils::ToStr(plugin->name)) + .method_call("set_body") + .call_arg(func_name) + .line(); +} + +void TVMPluginCodeGen::CodeGenCmake(const std::set& devices) { + Map flags; + flags.Set("PLUGIN_SUPPORT_TVM", ""); + CodeGenPreCmake(devices, flags); + stack_.line("set(CMAKE_CXX_STANDARD 17)") + .line("set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -Wno-macro-redefined\")") + .line() + .line("set(TVM_ROOT " + config()->tvm_root + ")") + .line("find_library(TVM_LIB NAMES tvm HINTS ${TVM_ROOT}/build NO_DEFAULT_PATH)"); + Array includes, libs; + includes.push_back("${TVM_ROOT}/include"); + includes.push_back("${TVM_ROOT}/3rdparty/dmlc-core/include"); + includes.push_back("${TVM_ROOT}/3rdparty/dlpack/include"); + includes.push_back("${TVM_ROOT}/3rdparty/compiler-rt"); + libs.push_back("${TVM_LIB}"); + CodeGenPostCmake(devices, includes, libs); +} + +void TVMPluginCodeGen::CodeGenManagerDepends() { + BasePluginCodeGen::CodeGenManagerDepends(); + stack_.line("from tvm import relax") + .line("from tvm.relax import call_dps_packed") + .line("from 
tvm.contrib.msc.plugin import utils as plugin_utils") + .line("from tvm.contrib.msc.core import utils as msc_utils") + .line(); +} + +void TVMPluginCodeGen::CodeGenManagerMethods() { + BasePluginCodeGen::CodeGenManagerMethods(); + stack_.func_def("setup") + .func_arg("self", "object") + .func_start() + .for_start("lib", "os.listdir(self._lib_folder)") + .assign("lib_file", "os.path.join(self._lib_folder, lib)") + .func_call("CDLL", "", "ctypes") + .call_arg("lib_file") + .for_end() + .line("from tvm.contrib.msc.plugin.op import _ffi_api") + .assign(DocUtils::ToAttrAccess("self", "_ffi_api"), "_ffi_api") + .func_end(); +} + +void TVMPluginCodeGen::CodeGenOpBuilder(const Plugin& plugin) { + stack_.func_def(plugin->name).func_arg("self", "object"); + for (const auto& t : plugin->inputs) { + stack_.func_arg(t->name, "relax.Expr"); + } + for (const auto& attr : plugin->attrs) { + stack_.func_arg(attr->name, ToPyType(attr->type), attr->default_value); + } + stack_.func_arg("name", "str", "\"" + plugin->name + "\"").func_start(); + Array args; + for (const auto& t : plugin->inputs) { + args.push_back(t->name); + } + for (const auto& a : plugin->attrs) { + stack_.func_call("plugin_utils.to_expr", a->name).call_arg(a->name); + args.push_back(a->name); + } + stack_.func_call("relax.Tuple", "args") + .call_arg(DocUtils::ToList(args)) + .func_call("InferStructInfo" + plugin->name, "out_sinfo", "self._ffi_api"); + for (const auto& t : plugin->inputs) { + stack_.call_arg(t->name); + } + for (const auto& a : plugin->attrs) { + stack_.call_arg(a->name); + } + stack_.func_call("call_dps_packed", "op") + .call_arg(DocUtils::ToStr(plugin->name)) + .call_arg("args", "args") + .call_arg("list(out_sinfo)", "out_sinfo") + .func_call("msc_utils.set_expr_name", "op") + .call_arg("op") + .call_arg("name"); + stack_.func_end("op").comment(GetPyComment(plugin), true); +} + +void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& device) { + if (plugin->externs.count(device + 
"_compute")) { + // compute with dtype + auto prepare_tensor = [this](const PluginTensor& tensor, const Map& dtypes, + size_t idx, const String& collect) { + const String& t_name = "d_" + tensor->name; + const String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; + const String& tensor_type = "DataTensor<" + t_dtype + ">"; + const String& anno = collect == "input" ? "const " + tensor_type + "&" : tensor_type; + stack_.func_call("TVMUtils::To" + tensor_type, DocUtils::ToDeclare(anno, t_name)) + .call_arg(tensor->name) + .call_arg(collect == "input"); + return t_name; + }; + for (const auto& dtypes : GetDtypeMatrix(plugin)) { + const auto& tensor_dtypes = GetTensorDtypes(plugin, dtypes); + Array compute_args; + String dtype_cond = ""; + for (size_t i = 0; i < plugin->inputs.size(); i++) { + const auto& t_name = plugin->inputs[i]->name; + dtype_cond = dtype_cond + "TVMUtils::ToMetaType(" + t_name + + "->dtype) == DataUtils::ToMetaType(\"" + dtypes.at(i) + "\")"; + dtype_cond = dtype_cond + (i == plugin->inputs.size() - 1 ? 
"" : " && "); + } + // prepare compute datas + stack_.cond_if(dtype_cond).comment("prepare compute datas"); + for (size_t i = 0; i < plugin->inputs.size(); i++) { + const String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); + compute_args.push_back(t_name); + } + for (size_t i = 0; i < plugin->outputs.size(); i++) { + const String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); + compute_args.push_back(t_name); + } + ICHECK(plugin->buffers.size() == 0) << "Plugin with buffers is not supported in tvm"; + compute_args.push_back("meta_attr"); + if (device == "cuda") { + stack_.assign("stream", "runtime::CUDAThreadEntry::ThreadLocal()->stream", "auto"); + compute_args.push_back("stream"); + } + CodeGenSafeCall(plugin->externs[device + "_compute"], compute_args); + stack_.cond_end(); + } + } else { + stack_.comment("Skip compute on " + device); + } +} + +TVM_REGISTER_GLOBAL("msc.plugin.GetTVMPluginSources") + .set_body_typed([](const String& codegen_config, const String& print_config, + const String& codegen_type) -> Map { + TVMPluginCodeGen codegen = TVMPluginCodeGen(codegen_config); + if (codegen_type == "build") { + return codegen.GetBuildSources(print_config); + } + if (codegen_type == "manager") { + return codegen.GetManagerSources(print_config); + } + return Map(); + }); + +} // namespace msc +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/msc/plugin/tvm_codegen.h b/src/contrib/msc/plugin/tvm_codegen.h new file mode 100644 index 000000000000..520e35de95c6 --- /dev/null +++ b/src/contrib/msc/plugin/tvm_codegen.h @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/plugin/tvm_codegen.h + * \brief Codegen for tvm plugin. + */ +#ifndef TVM_CONTRIB_MSC_PLUGIN_TVM_CODEGEN_H_ +#define TVM_CONTRIB_MSC_PLUGIN_TVM_CODEGEN_H_ + +#include +#include + +#include "base_codegen.h" +#include "codegen_utils.h" + +namespace tvm { +namespace contrib { +namespace msc { + +/*! + * \brief CodeGen config for tvm plugin + */ +struct TVMPluginCodeGenConfig { + bool as_relay{false}; + std::string tvm_root{"tvm"}; + PLUGIN_CODEGEN_CONFIG_MEMBERS + void Load(dmlc::JSONReader* reader) { + std::string key; + reader->BeginObject(); + while (reader->NextObjectItem(&key)) { + if (key == "as_relay") { + reader->Read(&as_relay); + } else if (key == "tvm_root") { + reader->Read(&tvm_root); + } else { + PLUGIN_CODEGEN_CONFIG_PARSE + } + } + } +}; + +class TVMPluginCodeGen : public BasePluginCodeGen { + public: + /*! + * \brief The constructor of TVMPluginCodeGen + * \param config the options for codegen. + */ + explicit TVMPluginCodeGen(const std::string& config = "") + : BasePluginCodeGen(config) {} + + protected: + /*! \brief Codegen plugin attr declare*/ + void CodeGenAttrDeclare(const Plugin& plugin) final; + + /*! \brief Codegen plugin attr define*/ + void CodeGenAttrDefine(const Plugin& plugin) final; + + /*! \brief Codegen plugin op declare*/ + void CodeGenOpDeclare(const Plugin& plugin) final; + + /*! 
\brief Codegen plugin op define*/ + void CodeGenOpDefine(const Plugin& plugin) final; + + /*! \brief Codegen plugin runtime*/ + void CodeGenOpRuntime(const Plugin& plugin) final; + + /*! \brief Codegen cmake file*/ + void CodeGenCmake(const std::set& devices) final; + + /*! \brief Codegen manager depends*/ + void CodeGenManagerDepends() final; + + /*! \brief Codegen manager methods*/ + void CodeGenManagerMethods() final; + + /*! \brief Codegen manager member for plugin*/ + void CodeGenOpBuilder(const Plugin& plugin) final; + + private: + /*! \brief Func name of compute*/ + const String ComputeName(const Plugin& plugin) { return plugin->name + "_compute"; } + + /*! \brief Codegen compute*/ + void CodeGenCompute(const Plugin& plugin, const String& device); + + /*! \brief Type name in tvm*/ + const String ToTVMType(const String& type) { + if (type == "string") { + return "StringImm"; + } + if (StringUtils::StartsWith(type, "float")) { + return "FloatImm"; + } + if (type == "bool" || StringUtils::StartsWith(type, "int")) { + return "IntImm"; + } + if (IsListType(type)) { + return "Tuple"; + } + return BasePluginCodeGen::ToCppType(type); + } +}; + +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_MSC_PLUGIN_TVM_CODEGEN_H_ diff --git a/tests/python/contrib/test_msc/test_plugin.py b/tests/python/contrib/test_msc/test_plugin.py new file mode 100644 index 000000000000..277268f8aee8 --- /dev/null +++ b/tests/python/contrib/test_msc/test_plugin.py @@ -0,0 +1,309 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Test Plugin in MSC. """ + +import numpy as np + +import torch +from torch import nn + +import tvm.testing +from tvm import relax +from tvm.relax.transform import BindParams +from tvm.script import relax as R +from tvm.contrib.msc.plugin import build_plugins +from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils + + +def _get_externs_header(): + """Get the header source for externs""" + + return """#ifndef EXTERNS_H_ +#define EXTERNS_H_ + +#include "plugin_base.h" + +#ifdef PLUGIN_ENABLE_CUDA +#include +#endif + +namespace tvm { +namespace contrib { +namespace msc { +namespace plugin { + +template +std::vector my_relu_infer(const std::vector& inputs, const TAttr& attrs, + bool is_runtime) { + std::vector outputs; + outputs.push_back(MetaTensor(inputs[0].shape(), inputs[0].data_type(), inputs[0].layout())); + return outputs; +} + +template +void my_relu_cpu_kernel(const DataTensor& input, DataTensor& output, T max_val); + +template +void my_relu_cpu_compute(const DataTensor& input, DataTensor& output, const TAttr& attrs) { + my_relu_cpu_kernel(input, output, T(attrs.max_val)); +} + +#ifdef PLUGIN_ENABLE_CUDA +template +void my_relu_cuda_kernel(const DataTensor& input, DataTensor& output, T max_val, + const cudaStream_t& stream); + +template +void my_relu_cuda_compute(const DataTensor& input, DataTensor& output, const TAttr& attrs, + const cudaStream_t& stream) { + my_relu_cuda_kernel(input, output, T(attrs.max_val), stream); +} +#endif + +} // namespace plugin +} // namespace msc +} // 
namespace contrib +} // namespace tvm +#endif // EXTERNS_H_ +""" + + +def _get_externs_cc(): + """Get externs cc source""" + return """#include "externs.h" + +namespace tvm { +namespace contrib { +namespace msc { +namespace plugin { + +template +void my_relu_cpu_kernel(const DataTensor& input, DataTensor& output, T max_val) { + const T* input_data = input.const_data(); + T* output_data = output.data(); + for (size_t i = 0; i < output.size(); i++) { + if (input_data[i] >= max_val) { + output_data[i] = max_val; + } else if (input_data[i] <= 0) { + output_data[i] = 0; + } else { + output_data[i] = input_data[i]; + } + } +} + +template void my_relu_cpu_kernel(const DataTensor& input, DataTensor& output, + float max_val); + +} // namespace plugin +} // namespace msc +} // namespace contrib +} // namespace tvm +""" + + +def _get_externs_cu(): + """Get externs cu source""" + + return """#include "externs.h" + +#define CU1DBLOCK 256 +#define KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) + +namespace tvm { +namespace contrib { +namespace msc { +namespace plugin { + +inline int n_blocks(int size, int block_size) { + return size / block_size + (size % block_size == 0 ? 
0 : 1); +} + +template +__global__ static void _my_relu(const T* src, T* dst, T max_val, int n) { + KERNEL_LOOP(i, n) { + if (src[i] >= max_val) { + dst[i] = max_val; + } else if (src[i] <= 0) { + dst[i] = 0; + } else { + dst[i] = src[i]; + } + } +} + +template +void my_relu_cuda_kernel(const DataTensor& input, DataTensor& output, T max_val, + const cudaStream_t& stream) { + const T* input_data = input.const_data(); + T* output_data = output.data(); + dim3 Bl(CU1DBLOCK); + dim3 Gr(n_blocks(output.size(), CU1DBLOCK)); + _my_relu<<>>(input_data, output_data, max_val, output.size()); +} + +template void my_relu_cuda_kernel(const DataTensor& input, DataTensor& output, + float max_val, const cudaStream_t& stream); + +} // namespace plugin +} // namespace msc +} // namespace contrib +} // namespace tvm +""" + + +def _create_plugin(externs_dir): + """Create sources under source folder""" + with open(externs_dir.relpath("externs.h"), "w") as f: + f.write(_get_externs_header()) + with open(externs_dir.relpath("externs.cc"), "w") as f: + f.write(_get_externs_cc()) + with open(externs_dir.relpath("externs.cu"), "w") as f: + f.write(_get_externs_cu()) + return { + "MyRelu": { + "inputs": [{"name": "input", "dtype": "T"}], + "outputs": [{"name": "output", "dtype": "T"}], + "attrs": [{"name": "max_val", "type": "float"}], + "support_dtypes": {"T": ["float"]}, + "externs": { + "infer_output": {"name": "my_relu_infer", "header": "externs.h"}, + "cpu_compute": { + "name": "my_relu_cpu_compute", + "header": "externs.h", + "source": "externs.cc", + }, + "cuda_compute": { + "name": "my_relu_cuda_compute", + "header": "externs.h", + "source": "externs.cu", + }, + }, + } + } + + +def _get_torch_model(torch_manager): + """Build model with plugin""" + + class MyModel(nn.Module): + """Test model with plugin""" + + def __init__(self): + super(MyModel, self).__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) + self.relu = torch_manager.MyRelu(max_val=0.5) + self.maxpool = 
nn.MaxPool2d(kernel_size=[1, 1]) + + def forward(self, data): + data = self.conv(data) + data = self.relu(data) + return self.maxpool(data) + + return MyModel() + + +def _get_tvm_model(tvm_manager): + """Build model with plugin""" + + block_builder = relax.BlockBuilder() + weights = np.random.rand(6, 3, 7, 7).astype("float32") + data = relax.Var("data", R.Tensor((1, 3, 224, 224), "float32")) + weight = relax.Var("weight", R.Tensor(weights.shape, weights.dtype.name)) + inputs = [data, weight] + with block_builder.function(name="main", params=inputs.copy()): + with block_builder.dataflow(): + data = relax.op.nn.conv2d(data, weight) + data = block_builder.emit(data, "conv2d") + data = tvm_manager.MyRelu(data, max_val=0.5) + data = block_builder.emit(data, "relu") + data = relax.op.nn.max_pool2d(data) + data = block_builder.emit(data, "max_pool2d") + data = block_builder.emit_output(data) + block_builder.emit_func_output(data) + mod = block_builder.finalize() + return BindParams("main", {"weight": tvm.nd.array(weights)})(mod) + + +def _build_plugin(frameworks, plugin_root): + externs_dir = plugin_root.create_dir("externs") + install_dir = plugin_root.create_dir("install") + plugin = _create_plugin(externs_dir) + managers = build_plugins(plugin, frameworks, install_dir, externs_dir=externs_dir) + return managers + + +def _run_relax(relax_mod, target_name, data): + target = tvm.target.Target(target_name) + relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod) + if target_name == "cuda": + with target: + relax_mod = tvm.tir.transform.DefaultGPUSchedule()(relax_mod) + device = tvm.cuda() + else: + device = tvm.cpu() + with tvm.transform.PassContext(opt_level=3): + relax_exec = tvm.relax.build(relax_mod, target) + runnable = tvm.relax.VirtualMachine(relax_exec, device) + data = tvm.nd.array(data, device) + return runnable["main"](data).asnumpy() + + +def _test_tvm_plugin(manager, target): + """Test plugin in tvm""" + + model = _get_tvm_model(manager) + data = 
np.random.rand(1, 3, 224, 224).astype("float32") + outputs = _run_relax(model, target, data) + assert outputs.min() >= 0 and outputs.max() <= 0.5 + + +def _test_torch_plugin(manager): + """Test plugin in torch""" + + model = _get_torch_model(manager) + torch_data = torch.from_numpy(np.random.rand(1, 3, 224, 224).astype("float32")) + if torch.cuda.is_available(): + model = model.to(torch.device("cuda:0")) + torch_data = torch_data.to(torch.device("cuda:0")) + outputs = model(torch_data) + assert outputs.min() >= 0 and outputs.max() <= 0.5 + + +def test_plugin(): + """Test the plugins""" + + frameworks = [MSCFramework.TORCH, MSCFramework.TVM] + if tvm.get_global_func("relax.ext.tensorrt", True) is not None: + frameworks.append(MSCFramework.TENSORRT) + plugin_root = msc_utils.msc_dir("msc_plugin") + managers = _build_plugin(frameworks, plugin_root) + + # test the plugin load + _test_tvm_plugin(managers[MSCFramework.TVM], "llvm") + if tvm.cuda().exist: + _test_tvm_plugin(managers[MSCFramework.TVM], "cuda") + _test_torch_plugin(managers[MSCFramework.TORCH]) + + plugin_root.destory() + + +if __name__ == "__main__": + tvm.testing.main()