diff --git a/python/tvm/contrib/msc/core/runtime/runner.py b/python/tvm/contrib/msc/core/runtime/runner.py
index 2849eb05ed83..6d3a364e90ec 100644
--- a/python/tvm/contrib/msc/core/runtime/runner.py
+++ b/python/tvm/contrib/msc/core/runtime/runner.py
@@ -57,6 +57,8 @@ class BaseRunner(object):
         Whether compile model to trainable
     stage: str
         The stage of runner.
+    plugin: PluginManager
+        The plugin manager.
     name: str
         The name of the runner
     debug_level: int
@@ -75,6 +77,7 @@ def __init__(
         device: str = "cpu",
         training: bool = False,
         stage: str = "default",
+        plugin: Any = None,
         name: str = "main",
         debug_level: int = 0,
         logger: logging.Logger = None,
@@ -86,6 +89,7 @@ def __init__(
         self._build_config = msc_utils.copy_dict(build_config)
         self._device = device if self._device_enabled(device) else "cpu"
         self._stage = stage
+        self._plugin = plugin
         self._name = name
         self._debug_level = debug_level
         self._training, self._trained = training, training
@@ -123,8 +127,11 @@ def setup(self) -> dict:
                 stage=self._stage,
                 **config,
             )
+        if self._plugin:
+            self._update_codegen({"use_plugin": True})
         return {
             "tools": {k: v.tool_style() for k, v in self._tools.items()},
+            "plugin": self._plugin,
             "translate_config": self._translate_config,
             "generate_config": self._generate_config,
             "build_config": self._build_config,
@@ -1069,6 +1076,7 @@ def _generate_model(self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.arra
             codegen_config=self._generate_config.get("codegen"),
             print_config=self._generate_config.get("print"),
             build_folder=self._generate_config["build_folder"],
+            plugin=self._plugin,
         )

     def _inspect_model(self) -> dict:
@@ -1226,6 +1234,7 @@ def _generate_model(self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.arra
             extra_options=extra_option,
             build_folder=self._generate_config["build_folder"],
             output_folder=self._generate_config.get("output_folder", msc_utils.get_output_dir()),
+            plugin=self._plugin,
         )

     def _build_runnable(self, model: Any) -> Any:
diff --git a/python/tvm/contrib/msc/core/tools/prune/pruner.py b/python/tvm/contrib/msc/core/tools/prune/pruner.py
index bb2ff9922073..7eb4434a62f3 100644
--- a/python/tvm/contrib/msc/core/tools/prune/pruner.py
+++ b/python/tvm/contrib/msc/core/tools/prune/pruner.py
@@ -82,6 +82,7 @@ def _parse_strategys(self, strategy_list: dict) -> Dict[str, ToolStrategy]:
         def _update_stages(strategy):
             if "stages" not in strategy:
                 strategy["stages"] = [msc_utils.MSCStage.PRUNE]
+            strategy["tensor_types"] = ["weight", "output"]
             return strategy

         return super()._parse_strategys([_update_stages(s) for s in strategy_list])
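The pruner change above pins every parsed strategy to weight and output tensors before handing it to the base parser. A hedged sketch of a strategy entry after `_update_stages` (key names follow the code; the density value and stage string are illustrative):

    # before: {"density": 0.8}
    strategy = {
        "density": 0.8,
        "stages": ["prune"],                   # filled in when missing
        "tensor_types": ["weight", "output"],  # always pinned by the pruner
    }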
""" + if self._calibrated: + tensor_id = self.to_tensor_id(name, consumer) + if tensor_id not in self._plan: + return False + return self._plan.get(tensor_id, {}).get("nbits", 8) != -1 strategys = self._get_tensor_strategys(name, consumer) if not strategys: return False diff --git a/python/tvm/contrib/msc/core/tools/tool.py b/python/tvm/contrib/msc/core/tools/tool.py index 7253841122ae..fec391339f20 100644 --- a/python/tvm/contrib/msc/core/tools/tool.py +++ b/python/tvm/contrib/msc/core/tools/tool.py @@ -409,7 +409,7 @@ def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, ToolStrategy] tensor_names = strategy.pop("tensor_names") marks = [(n, "tensor") for n in tensor_names] else: - marks = [("default", t) for t in ["input", "output", "weight"]] + marks = [("default." + str(t), t) for t in tensor_types] stages = strategy.pop("stages") if "stages" in strategy else ["default"] for mark, t_type in marks: if mark not in strategys: @@ -1212,33 +1212,38 @@ def _get_tensor_strategys(self, name: str, consumer: str) -> List[ToolStrategy]: tensor_id = self.to_tensor_id(name, consumer) mark = "strategy.{}".format(self._stage) + + def _check_strategy(s_ref): + return s_ref in self._strategys and self._strategys[s_ref].support_stage(self._stage) + if mark not in self._tensor_cache.get(tensor_id, {}): - if self.is_weight(name): + strategys = [] + tensor_strategy = self._strategys.get(tensor_id) + if tensor_strategy and tensor_strategy.support_stage(self._stage): + strategys.append(tensor_strategy) + elif self.is_weight(name): consumer = self.find_node(consumer) - name_refs = [consumer.name + ".weight", consumer.optype + ".weight"] + for ref in [consumer.name, consumer.optype, "default"]: + if _check_strategy(ref + ".weight"): + strategys.append(self._strategys[ref + ".weight"]) + break elif consumer == "exit": producer = self.find_producer(name) - name_refs = [producer.name + ".output", producer.optype + ".output"] + for ref in [producer.name, producer.optype, "exit", "default"]: + if _check_strategy(ref + ".output"): + strategys.append(self._strategys[ref + ".output"]) + break else: consumer = self.find_node(consumer) + for ref in [consumer.name, consumer.optype, "default"]: + if _check_strategy(ref + ".input"): + strategys.append(self._strategys[ref + ".input"]) + break producer = self.find_producer(name) - name_refs = [ - producer.name + ".output", - producer.optype + ".output", - consumer.name + ".input", - consumer.optype + ".input", - ] - strategys = [] - tensor_strategy = self._strategys.get(tensor_id) - if tensor_strategy and tensor_strategy.support_stage(self._stage): - strategys.append(tensor_strategy) - if not strategys: - for n in name_refs: - if n in self._strategys and self._strategys[n].support_stage(self._stage): - strategys.append(self._strategys[n]) - d_strategy = self._strategys.get("default") - if not strategys and d_strategy and d_strategy.support_stage(self._stage): - strategys.append(d_strategy) + for ref in [producer.name, producer.optype, "default"]: + if _check_strategy(ref + ".output"): + strategys.append(self._strategys[ref + ".output"]) + break self._save_tensor_cache(name, consumer, mark, strategys) return self._get_tensor_cache(name, consumer, mark) diff --git a/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py b/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py index 4555d23528ca..f24150efcd6c 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py +++ 
diff --git a/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py b/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py
index 4555d23528ca..f24150efcd6c 100644
--- a/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py
+++ b/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py
@@ -16,7 +16,7 @@
 # under the License.
 """tvm.contrib.msc.framework.tensorflow.codegen.codegen"""

-from typing import Dict, Optional
+from typing import Dict, Optional, Any

 import tvm
 from tvm.contrib.msc.core.ir import MSCGraph
@@ -32,6 +32,7 @@ def to_tensorflow(
     codegen_config: Optional[Dict[str, str]] = None,
     print_config: Optional[Dict[str, str]] = None,
     build_folder: msc_utils.MSCDirectory = None,
+    plugin: Any = None,
 ) -> tf_v1.Graph:
     """Change MSCGraph to tensorflow graph.

@@ -47,6 +48,8 @@ def to_tensorflow(
         The config for print.
     build_folder: MSCDirectory
         The folder for saving scripts and datas.
+    plugin: PluginManager
+        The plugin manager.

     Returns
     -------
@@ -63,4 +66,7 @@ def _save_weights(folder: msc_utils.MSCDirectory):
     codegen = CodeGen(
         graph, _ffi_api.GetTensorflowSources, codegen_config, print_config, build_folder
     )
-    return codegen.load(inputs + [weights], pre_load=_save_weights)
+    model_args = inputs + [weights]
+    if plugin:
+        model_args = model_args + [plugin]
+    return codegen.load(model_args, pre_load=_save_weights)
diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py
index d72b14cfd53e..4643d49c1e83 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py
@@ -18,7 +18,7 @@

 import os
 import subprocess
-from typing import Dict, Optional, List, Union
+from typing import Dict, Optional, List, Union, Any
 import numpy as np

 import tvm
@@ -38,6 +38,7 @@ def to_sub_tensorrt(
     print_config: Optional[Dict[str, str]] = None,
     build_folder: msc_utils.MSCDirectory = None,
     output_folder: msc_utils.MSCDirectory = None,
+    plugin: Any = None,
 ) -> str:
     """Change MSCGraph to TensorRT engine file.

@@ -55,6 +56,8 @@ def to_sub_tensorrt(
         The folder for saving sources and datas.
     export_folder: MSCDirectory
         The folder for saving outputs.
+    plugin: PluginManager
+        The plugin manager.

     Returns
     -------
@@ -90,6 +93,10 @@ def _create_depends(folder: msc_utils.MSCDirectory) -> str:
                 f.write("{}\n".format(len(engine_wts)))
                 for name, data in engine_wts.items():
                     write_weight(name, msc_utils.cast_array(data), f)
+        # copy plugin
+        if plugin:
+            plugin.copy_libs("plugin_lib")
+            plugin.copy_includes("plugin")
         # save utils sources
         with folder.create_dir("utils") as utils_folder:
             for name, source in get_trt_sources().items():
@@ -115,6 +122,10 @@ def _build_engine(engine_name: str, folder: msc_utils.MSCDirectory) -> str:

     with build_folder as folder:
         sub_folder = folder.create_dir(graph.name)
+        if plugin:
+            codegen_config["extern_libs"] = [
+                sub_folder.create_dir("plugin_lib").relpath(f) for f in plugin.list_libs()
+            ]
         codegen = CodeGen(
             graph,
             _ffi_api.GetTensorRTSources,
@@ -140,6 +151,7 @@ def to_tensorrt(
     extra_options: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,
     build_folder: msc_utils.MSCDirectory = None,
     output_folder: msc_utils.MSCDirectory = None,
+    plugin: Any = None,
 ) -> Dict[str, str]:
     """Change all MSCGraphs to TensorRT engine files.

@@ -161,6 +173,8 @@ def to_tensorrt(
         The folder for saving sources and datas.
     export_folder: MSCDirectory
         The folder for saving outputs.
+    plugin: PluginManager
+        The plugin manager.

     Returns
     -------
@@ -183,6 +197,7 @@ def to_tensorrt(
             print_configs[idx],
             build_folder,
             output_folder,
+            plugin=plugin,
         )
         if extra_options[idx]:
             options.update(extra_options[idx])
diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py
index effa86595dff..8eea3f7081a7 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py
@@ -20,11 +20,13 @@
 from typing import Mapping, Tuple, List, Union, Callable, Dict
 from functools import wraps, partial

+import tvm
 from tvm import relax
 from tvm.relax.dpl import pattern
 from tvm.relax.transform import PatternCheckContext, FusionPattern
 from tvm.relax.backend.pattern_registry import register_patterns
 from tvm.contrib.msc.core.transform import pattern as msc_pattern
+from tvm.contrib.msc.core import _ffi_api


 def basic_pattern(
@@ -234,6 +236,41 @@ def _take_check(context: PatternCheckContext) -> bool:
     return _check_expr(context.annotated_expr["input_1"], ("int32"))


+def _plugin_check(context: PatternCheckContext) -> bool:
+    """Check if the plugin pattern is correct.
+
+    Returns
+    -------
+    pass: bool
+        Whether the pattern is correct.
+    """
+
+    ext_func = context.annotated_expr["out"].args[0]
+    return bool(_ffi_api.IsPlugin(ext_func.global_symbol))
+
+
+def plugin_attrs_getter(
+    annotated_expr: Dict[str, tvm.relax.Expr],
+) -> Dict[str, str]:
+    """Get attributes for plugin pattern
+
+    Parameters
+    ----------
+    annotated_expr: dict
+        The annotated exprs during the fuse pattern
+
+    Returns
+    -------
+    attrs: dict
+        The extra attributes for msc.
+    """
+
+    attrs = msc_pattern.msc_attrs_getter(annotated_expr, anchor="out")
+    ext_func = annotated_expr["out"].args[0]
+    attrs[_ffi_api.ToAttrKey("optype")] = ext_func.global_symbol
+    return attrs
+
+
 def wrap_basic_check(
     func: Callable[[PatternCheckContext], bool]
 ) -> Callable[[PatternCheckContext], bool]:
@@ -410,6 +449,15 @@ def get_patterns(target) -> List[Pattern]:
             ),
         ]
     )
+    # plugin ops
+    patterns.append(
+        (
+            target + ".plugin",
+            *basic_pattern("relax.call_dps_packed", ["input", "input"]),
+            _plugin_check,
+            plugin_attrs_getter,
+        )
+    )
     return patterns
diff --git a/python/tvm/contrib/msc/framework/torch/codegen/codegen.py b/python/tvm/contrib/msc/framework/torch/codegen/codegen.py
index f885c81aa652..d4aeabb10a1b 100644
--- a/python/tvm/contrib/msc/framework/torch/codegen/codegen.py
+++ b/python/tvm/contrib/msc/framework/torch/codegen/codegen.py
@@ -16,7 +16,7 @@
 # under the License.
 """tvm.contrib.msc.framework.torch.codegen.codegen"""

-from typing import Dict, Optional
+from typing import Dict, Optional, Any

 import torch
 import tvm
@@ -32,6 +32,7 @@ def to_torch(
     codegen_config: Optional[Dict[str, str]] = None,
     print_config: Optional[Dict[str, str]] = None,
     build_folder: msc_utils.MSCDirectory = None,
+    plugin: Any = None,
 ) -> torch.nn.Module:
     """Change MSCGraph to torch nn.Module.

@@ -47,6 +48,8 @@ def to_torch(
         The config for print.
     build_folder: MSCDirectory
         The folder for saving scripts and datas.
+    plugin: PluginManager
+        The plugin manager.

     Returns
     -------
@@ -73,4 +76,5 @@ def _bind_weights(model: torch.nn.Module, folder: msc_utils.MSCDirectory) -> tor
         return model

     codegen = CodeGen(graph, _ffi_api.GetTorchSources, codegen_config, print_config, build_folder)
-    return codegen.load([], pre_load=_save_weights, post_load=_bind_weights)
+    model_args = [plugin] if plugin else []
+    return codegen.load(model_args, pre_load=_save_weights, post_load=_bind_weights)
diff --git a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py b/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py
index c344b9260644..4038b74b7ea2 100644
--- a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py
+++ b/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py
@@ -16,7 +16,7 @@
 # under the License.
 """tvm.contrib.msc.framework.tvm.codegen.codegen"""

-from typing import Dict, Optional
+from typing import Dict, Optional, Any

 import tvm
 from tvm.relax.transform import BindParams
@@ -32,6 +32,7 @@ def to_relax(
     codegen_config: Optional[Dict[str, str]] = None,
     print_config: Optional[Dict[str, str]] = None,
     build_folder: msc_utils.MSCDirectory = None,
+    plugin: Any = None,
 ) -> tvm.IRModule:
     """Change MSCGraph to IRModule.

@@ -47,6 +48,8 @@ def to_relax(
         The config for print.
     build_folder: MSCDirectory
         The folder for saving scripts and datas.
+    plugin: PluginManager
+        The plugin manager.

     Returns
     -------
@@ -81,4 +84,7 @@ def _post_proc(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModul
     )(mod)

     codegen = CodeGen(graph, _ffi_api.GetRelaxSources, codegen_config, print_config, build_folder)
-    return codegen.load(inputs, pre_load=_save_weights, post_load=_post_proc)
+    model_args = inputs
+    if plugin:
+        model_args = model_args + [plugin]
+    return codegen.load(model_args, pre_load=_save_weights, post_load=_post_proc)
""" - def __init__(self, model: Any, config: dict): + def __init__(self, model: Any, config: dict, plugins: dict = None, root: str = None): + # change path to root path + if root: + + def _from_root_mark(val): + if root and isinstance(val, str) and MSCKey.ROOT_MARK in val: + return val.replace(MSCKey.ROOT_MARK, root) + return val + + model = _from_root_mark(model) + config = msc_utils.map_dict(config, _from_root_mark) + plugins = msc_utils.map_dict(plugins, _from_root_mark) + # check stage for stage in ["inputs", "outputs", "dataset", MSCStage.PREPARE, MSCStage.COMPILE]: assert stage in config, "{} should be given to run the pipeline".format(stage) @@ -55,6 +73,10 @@ def __init__(self, model: Any, config: dict): self._model, self._device, self._training = self._get_runner_cls( self._model_type ).load_native(model) + if plugins: + self._plugins = load_plugins(plugins) + else: + self._plugins = {} use_cache = config.get("use_cache", True) self._workspace = msc_utils.set_workspace(config.get("workspace"), use_cache) self._verbose = config.get("verbose", "info") @@ -87,6 +109,12 @@ def setup(self, config: dict) -> dict: self._meta_config = config self._optimize_type = config.get(MSCStage.OPTIMIZE, {}).get("run_type", self._model_type) self._compile_type = config.get(MSCStage.COMPILE, {}).get("run_type", self._model_type) + # register plugins + if self._plugins: + for t in [self._model_type, self._optimize_type, self._compile_type]: + assert t in self._plugins, "Missing plugin for {}".format(t) + for name, plugin in self._plugins[self._model_type].get_ops_info().items(): + _ffi_api.RegisterPlugin(name, msc_utils.dump_dict(plugin)) self._config, self._debug_levels = self.update_config(config) self._tools_config = {} self._relax_mod, self._runner = None, None @@ -100,7 +128,7 @@ def setup(self, config: dict) -> dict: "duration": {}, "profile": {}, } - return {"workspace": self._workspace.path, "config": config} + return {"workspace": self._workspace.path, "plugins": self._plugins, "config": config} def update_config(self, config: dict) -> dict: """Update config @@ -300,20 +328,22 @@ def parse(self) -> tvm.IRModule: self._logger.info("Load parsed mod from %s", cache_path) else: parse_config = msc_utils.copy_dict(stage_config.get("parse_config", {})) - runner_cls = self._get_runner_cls(self._config[MSCStage.COMPILE]["run_type"]) - trans_func = ( - runner_cls.target_transform if hasattr(runner_cls, "target_transform") else None - ) - parse_info = { - "parser": stage_config["parser"], - "config": parse_config, - "trans_func": trans_func, - } + parse_info = {"parser": stage_config["parser"], "config": parse_config} self._logger.info(msc_utils.msg_block("PARSE", parse_info)) parse_config["as_msc"] = False + if self._model_type in self._plugins: + plugin = self._plugins[self._model_type] + parse_config["custom_convert_map"] = plugin.get_convert_map() self._relax_mod, _ = stage_config["parser"](self._model, **parse_config) - if trans_func: - self._relax_mod = trans_func(self._relax_mod) + for stage in [MSCStage.OPTIMIZE, MSCStage.COMPILE]: + if stage not in self._config: + continue + runner_cls = self._get_runner_cls(self._config[stage]["run_type"]) + if hasattr(runner_cls, "target_transform"): + self._logger.info( + "Transform for stage %s: %s", stage, runner_cls.target_transform + ) + self._relax_mod = runner_cls.target_transform(self._relax_mod) self._relax_mod = msc_transform.SetExprName()(self._relax_mod) if cache_path: with open(cache_path, "w") as f: @@ -498,6 +528,7 @@ def _create_runner( runner 
diff --git a/src/contrib/msc/core/codegen/base_codegen.h b/src/contrib/msc/core/codegen/base_codegen.h
index 0b5756df7e63..19d8b524b9e2 100644
--- a/src/contrib/msc/core/codegen/base_codegen.h
+++ b/src/contrib/msc/core/codegen/base_codegen.h
@@ -33,6 +33,7 @@
 #include

 #include "../ir/graph.h"
+#include "../ir/plugin.h"
 #include "code_stack.h"
 #include "codegen_utils.h"

@@ -81,6 +82,36 @@ class BaseOpCode {
     return IdxWeightBase(node_, wtype, process);
   }

+  /*! \brief Get the node attr as doc*/
+  const ExprDoc GetAttrDoc(const String& key, const String& type) {
+    if (StringUtils::StartsWith(type, "list")) {
+      const String& ele_type =
+          StringUtils::Replace(StringUtils::Replace(type, "list(", ""), ")", "");
+      if (ele_type == "bool") {
+        return DocUtils::ToList(node_->GetTypeArrayAttr<bool>(key));
+      } else if (ele_type == "int" || ele_type == "int32") {
+        return DocUtils::ToList(node_->GetTypeArrayAttr<int>(key));
+      } else if (ele_type == "long" || ele_type == "int64") {
+        return DocUtils::ToList(node_->GetTypeArrayAttr<int64_t>(key));
+      } else if (ele_type == "float" || ele_type == "float32") {
+        return DocUtils::ToList(node_->GetTypeArrayAttr<float>(key));
+      } else if (ele_type == "string") {
+        return DocUtils::ToStrList(node_->GetTypeArrayAttr<std::string>(key));
+      }
+    } else if (type == "bool") {
+      return DocUtils::ToDoc(node_->GetTypeAttr<bool>(key));
+    } else if (type == "int" || type == "int32") {
+      return DocUtils::ToDoc(node_->GetTypeAttr<int>(key));
+    } else if (type == "long" || type == "int64") {
+      return DocUtils::ToDoc(node_->GetTypeAttr<int64_t>(key));
+    } else if (type == "float" || type == "float32") {
+      return DocUtils::ToDoc(node_->GetTypeAttr<float>(key));
+    } else if (type == "string") {
+      return DocUtils::ToStr(node_->GetTypeAttr<std::string>(key));
+    }
+    return DocUtils::ToDoc(node_->GetTypeAttr<std::string>(key));
+  }
+
   /*! \brief Get comment for default node*/
   const String Comment() { return Comment(node_); }

@@ -169,6 +200,14 @@ class BaseCodeGen {
     }
   }

+  /*! \brief Get the optype for op codegen*/
+  const String GetOpType(const MSCJoint& node) {
+    if (config_->use_plugin && IsPlugin(node->optype)) {
+      return "plugin";
+    }
+    return node->optype;
+  }
+
   /*! \brief Get the docs for the op*/
   virtual const Array<Doc> GetOpCodes(const MSCJoint& node) = 0;
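`GetAttrDoc` renders a plugin attribute as source text from its declared type string ("int", "float32", "list(string)", ...). A minimal Python analogue of that dispatch, purely illustrative:

    def render_attr(value, type_str):
        """Mirror of BaseOpCode::GetAttrDoc: render an attr by its type string."""
        if type_str.startswith("list("):
            ele_type = type_str[len("list("):-1]
            return "[" + ", ".join(render_attr(v, ele_type) for v in value) + "]"
        if type_str == "string":
            return '"{}"'.format(value)  # quoted, like DocUtils::ToStr
        return str(value)  # bool/int/long/float render as plain literals

    assert render_attr([1, 2], "list(int)") == "[1, 2]"
    assert render_attr("NCHW", "string") == '"NCHW"'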
diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc
index affb653a230d..0ece7a51cac8 100644
--- a/src/contrib/msc/core/transform/set_expr_layout.cc
+++ b/src/contrib/msc/core/transform/set_expr_layout.cc
@@ -509,6 +509,21 @@ InferLayoutOutput ForwardInferLayoutTake(const Call& call,
   return InferLayoutOutput({LayoutDecision("WE"), input_layout}, {output_layout}, Attrs());
 }

+InferLayoutOutput ForwardInferLayoutPlugin(const Call& call,
+                                           const Map<String, Array<String>>& desired_layouts,
+                                           const VarLayoutMap& var_layout_map) {
+  if (!call->args[0]->IsInstance<ExternFuncNode>()) {
+    return InferLayoutOutput();
+  }
+  const auto& name = Downcast<ExternFunc>(call->args[0])->global_symbol;
+  const auto* pf = runtime::Registry::Get("msc.plugin.op.InferLayout" + name);
+  if (pf == nullptr) {
+    return InferLayoutOutput();
+  }
+  const auto& args = Downcast<Tuple>(call->args[1]);
+  return (*pf)(args->fields, var_layout_map);
+}
+
 TVM_REGISTER_OP("relax.nn.conv1d")
     .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", MSCInferLayoutConv);
 TVM_REGISTER_OP("relax.nn.conv2d")
     .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", MSCInferLayoutConv);
@@ -603,6 +618,10 @@ TVM_REGISTER_OP("relax.nn.group_norm")
     .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutNormalize);
 TVM_REGISTER_OP("relax.nn.layer_norm")
     .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutNormalize);

+// plugin op
+TVM_REGISTER_OP("relax.call_dps_packed")
+    .set_attr<FRelaxInferLayout>("FMSCForwardInferLayout", ForwardInferLayoutPlugin);
+
 // Backward Infer
 InferLayoutOutput BackwardInferLayoutCommon(const Call& call,
                                             const Map<String, Array<String>>& desired_layouts,
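`ForwardInferLayoutPlugin` resolves its hook purely by name: `msc.plugin.op.InferLayout` plus the plugin's `global_symbol`. A plugin package can therefore provide the hook through TVM's global function registry; the sketch below only shows the naming contract (the op name is a placeholder, and a real hook must return the `InferLayoutOutput` the C++ caller expects — in practice these hooks are emitted by the plugin codegen rather than written by hand):

    import tvm

    @tvm.register_func("msc.plugin.op.InferLayoutMyRelu")
    def _infer_layout_my_relu(args, var_layout_map):
        # placeholder body: build and return an InferLayoutOutput describing
        # the layout decisions for the plugin op's inputs and outputs
        ...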
diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc
index 13c231092a8f..717eb75e1f36 100644
--- a/src/contrib/msc/framework/tensorrt/codegen.cc
+++ b/src/contrib/msc/framework/tensorrt/codegen.cc
@@ -27,6 +27,8 @@
 #include
 #include

+#include <set>
+
 #include "../../core/codegen/codegen_json.h"

 namespace tvm {
@@ -43,6 +45,17 @@ void TensorRTCodeGen::CodeGenClassDeclare() {
   if (config()->precision == "int8") {
     stack_.line("#include \"utils/trt_quantize.h\"");
   }
+  // plugin headers
+  if (config()->use_plugin) {
+    std::set<String> plugins;
+    for (const auto& n : graph()->node_names) {
+      const auto& node = graph()->FindNode(n);
+      if (IsPlugin(node->optype) && !plugins.count(node->optype)) {
+        stack_.line("#include \"plugin/" + node->optype + "_op.h\"");
+        plugins.insert(node->optype);
+      }
+    }
+  }
   stack_.line().line("using namespace nvinfer1;").line();
   StartNamespace();
   // start class declare
@@ -439,6 +452,7 @@ void TensorRTCodeGen::CodeGenCmake() {
   stack_.line("cmake_minimum_required(VERSION " + config()->cmake_version + " FATAL_ERROR)")
       .line("project(" + graph()->name + ")")
       .line("find_package(CUDA)")
+      .line()
       .line("find_path(TRT_INCLUDE_DIR NvInfer.h HINTS " + config()->tensorrt_root +
             " PATH_SUFFIXES include)")
       .line("find_library(TRT_LIBS nvinfer HINTS " + config()->tensorrt_root +
             " PATH_SUFFIXES lib)")
       .line(
           "message(STATUS \"Build project with TRT_INCLUDE_DIR ${TRT_INCLUDE_DIR} and "
           "TRT_LIBS "
           "${TRT_LIBS}\")")
+      .line()
       .line("add_definitions(-DTRT_MAJOR=" + std::to_string(config()->version[0]) + ")")
       .line("add_definitions(-DTRT_MINOR=" + std::to_string(config()->version[1]) + ")")
       .line("add_definitions(-DTRT_PATCH=" + std::to_string(config()->version[2]) + ")")
-      .line("file(GLOB_RECURSE TRT_SRCS *.cc)")
+      .line();
+  if (config()->use_plugin) {
+    stack_.line("add_definitions(-DPLUGIN_SUPPORT_TENSORRT)").line();
+  }
+  String link_libs = " ${TRT_LIBS}";
+  if (config()->extern_libs.size() > 0) {
+    stack_.line("set(EXTERN_LIBS " + StringUtils::Join(config()->extern_libs, " ") + ")");
+    link_libs = link_libs + " ${EXTERN_LIBS}";
+  }
+  stack_.line("file(GLOB_RECURSE TRT_SRCS *.cc)")
       .line("cuda_add_executable(" + graph()->name + " ${TRT_SRCS})")
       .line("target_include_directories(" + graph()->name + " PUBLIC ${TRT_INCLUDE_DIR})")
-      .line("target_link_libraries(" + graph()->name + " ${TRT_LIBS})");
+      .line("target_link_libraries(" + graph()->name + link_libs + ")");
 }

@@ -518,7 +542,7 @@ const String TensorRTCodeGen::ToDims(const Array<Integer>& dims, bool use_ndim)

 const Array<Doc> TensorRTCodeGen::GetOpCodes(const MSCJoint& node) {
   const auto& ops_map = GetTensorRTOpCodes();
-  auto it = ops_map->find(node->optype);
+  auto it = ops_map->find(GetOpType(node));
   ICHECK(it != ops_map->end()) << "Unsupported tensorrt op(" << node->optype << "): " << node;
   it->second->Config(node, config());
   try {
diff --git a/src/contrib/msc/framework/tensorrt/codegen_utils.h b/src/contrib/msc/framework/tensorrt/codegen_utils.h
index bfaecb8d3dc8..f006b21b816e 100644
--- a/src/contrib/msc/framework/tensorrt/codegen_utils.h
+++ b/src/contrib/msc/framework/tensorrt/codegen_utils.h
@@ -25,6 +25,7 @@
 #define TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_CODEGEN_UTILS_H_

 #include
+#include <vector>

 #include "../../core/codegen/base_codegen.h"
 #include "../../core/codegen/codegen_utils.h"
@@ -89,6 +90,7 @@ struct TensorRTCodeGenConfig {
   std::string precision{"float32"};
   std::string precision_mode{"strict"};
   std::string tensorrt_root{"/usr/local/cuda"};
+  std::vector<std::string> extern_libs;
   CODEGEN_CONFIG_MEMBERS
   void Load(dmlc::JSONReader* reader) {
     std::string key;
@@ -114,6 +116,8 @@ struct TensorRTCodeGenConfig {
         reader->Read(&precision_mode);
       } else if (key == "tensorrt_root") {
         reader->Read(&tensorrt_root);
+      } else if (key == "extern_libs") {
+        reader->Read(&extern_libs);
       } else {
         CODEGEN_CONFIG_PARSE
       }
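The new `extern_libs` field arrives with the rest of the JSON codegen config, and `to_sub_tensorrt` above fills it from `plugin.list_libs()` automatically. Passed by hand it would look like this (paths illustrative):

    codegen_config = {
        "tensorrt_root": "/usr/local/tensorrt",      # install hint for find_path
        "extern_libs": ["plugin_lib/libplugin.so"],  # linked after ${TRT_LIBS}
    }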
diff --git a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc
index 2cacd11907ff..a080fdd77862 100644
--- a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc
+++ b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc
@@ -73,15 +73,15 @@ const String TensorRTOpCode::DeclareInputs(bool simplify) {
     const auto& idx_input = StringUtils::Replace(IdxInput(), "*", "");
     stack_.declare("std::vector<ITensor*>", inputs_ref + "_vec")
         .declare_arg(node()->inputs.size())
-        .declare_arg(idx_input)
-        .assign(inputs_ref, inputs_ref + "_vec.data()", "ITensor**");
+        .declare_arg(idx_input);
   } else {
-    stack_.declare("std::vector<ITensor*>", IdxNode(), 0, false);
+    stack_.declare("std::vector<ITensor*>", inputs_ref + "_vec", 0, false);
     for (size_t i = 0; i < node()->inputs.size(); i++) {
       const auto& idx_input = StringUtils::Replace(IdxInput(i), "*", "");
       stack_.declare_arg(idx_input);
     }
   }
+  stack_.assign(inputs_ref, inputs_ref + "_vec.data()", "ITensor**");
   return inputs_ref;
 }

@@ -298,10 +298,7 @@ class TensorRTConcatCodeGen : public TensorRTOpCode {
     const auto& producer = node()->ProducerOf(0);
     ICHECK(node()->parents.size() == 1 && producer->optype == "tuple")
         << "Concat expect parent as tuple, get " << node();
-    stack_.op_call()
-        .inplace_start("data", "", IdxNodeBase(producer))
-        .inplace_end()
-        .call_arg(producer->inputs.size());
+    stack_.op_call().call_arg(IdxNodeBase(producer)).call_arg(producer->inputs.size());
     SetLayerByValue("Axis", AttrToAxis());
   }
 };
@@ -679,11 +676,8 @@ class TensorRTTupleCodeGen : public TensorRTOpCode {
 protected:
   void CodeGenBuild() final {
-    stack_.declare("std::vector<ITensor*>", IdxNode(), 0, false);
-    for (size_t i = 0; i < node()->inputs.size(); i++) {
-      const auto& idx_input = StringUtils::Replace(IdxInput(i), "*", "");
-      stack_.declare_arg(idx_input);
-    }
+    const auto& inputs_ref = DeclareInputs();
+    stack_.assign(IdxNode(), inputs_ref, "auto");
   }
 };

@@ -710,6 +704,35 @@ class TensorRTWhereCodeGen : public TensorRTOpCode {
   void CodeGenBuild() final { stack_.op_call().op_inputs_arg(false); }
 };

+class TensorRTPluginOpCodeGen : public TensorRTOpCode {
+ public:
+  TENSORRT_OP_CODEGEN_METHODS(TensorRTPluginOpCodeGen)
+
+ protected:
+  void CodeGenBuild() final {
+    const auto& producer = node()->ParentAt(0);
+    ICHECK(producer->optype == "tuple")
+        << "Only support tensorrt plugin with tuple, get " << producer;
+
+    const auto& plugin = GetPlugin(node()->optype);
+    const auto& input_ref = "inputs_" + std::to_string(producer->index);
+    const String& func_name = "plugin::" + node()->optype + "DynamicPlugin";
+    const String& plugin_ref = "plugin_" + std::to_string(node()->index);
+    const String& layouts_ref = "layouts_" + std::to_string(node()->index);
+    stack_.declare("std::vector<std::string>", layouts_ref, 0, false);
+    for (const auto& i : node()->GetInputs()) {
+      stack_.declare_arg(DocUtils::ToStr(i->layout.name()));
+    }
+    stack_.func_call(func_name, DocUtils::ToDeclare("auto", plugin_ref))
+        .call_arg(DocUtils::ToStr(node()->name));
+    for (const auto& a : plugin->attrs) {
+      stack_.call_arg(GetAttrDoc(a->name, a->type));
+    }
+    stack_.call_arg(layouts_ref);
+    stack_.op_call().call_arg(input_ref).call_arg(plugin->inputs.size()).call_arg(plugin_ref);
+  }
+};
+
 const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TensorRTOpCode>>>
 GetTensorRTOpCodes() {
   static auto map = std::make_shared<std::unordered_map<String, std::shared_ptr<TensorRTOpCode>>>();
@@ -796,16 +819,15 @@ GetTensorRTOpCodes() {

   // special op
   map->emplace("input", std::make_shared("Input"));
+  map->emplace("get_item", std::make_shared(""));
+  map->emplace("tuple", std::make_shared<TensorRTTupleCodeGen>(""));
+  map->emplace("plugin", std::make_shared<TensorRTPluginOpCodeGen>("PluginV2"));

   // msc ops
   map->emplace("msc.conv2d_bias", std::make_shared("ConvolutionNd", true));
   map->emplace("msc.linear", std::make_shared("FullyConnected", false));
   map->emplace("msc.linear_bias", std::make_shared("FullyConnected", true));

-  // special op
-  map->emplace("get_item", std::make_shared(""));
-  map->emplace("tuple", std::make_shared<TensorRTTupleCodeGen>(""));
-
   return map;
 }
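With this registration, all plugin nodes share one codegen entry: `GetOpType` (added in base_codegen.h above) collapses any registered plugin optype to the key "plugin" before the `ops_map` lookup. The same dispatch in Python terms, for illustration:

    def get_op_type(optype, use_plugin, is_plugin):
        """Collapse plugin optypes to the shared "plugin" codegen entry."""
        return "plugin" if use_plugin and is_plugin(optype) else optype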
.func_call("PluginManager", "plugin"); + } + stack_.comment("Build Model").func_call(graph()->name, "model"); + if (config()->use_plugin) { + stack_.call_arg("plugin"); + } + stack_.comment("Load weights") .func_call("torch.load", "weights") .call_arg(DocUtils::ToStr(graph()->name + ".pth")) .func_call("load_state_dict", "", "model") @@ -126,7 +140,7 @@ void TorchCodeGen::CodeGenInference() { const Array TorchCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetTorchOpCodes(); - auto it = ops_map->find(node->optype); + auto it = ops_map->find(GetOpType(node)); ICHECK(it != ops_map->end()) << "Unsupported torch op(" << node->optype << "): " << node; it->second->Config(node, config(), is_init_); try { diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc index 5086678758f2..59d30e774000 100644 --- a/src/contrib/msc/framework/torch/torch_opcode.cc +++ b/src/contrib/msc/framework/torch/torch_opcode.cc @@ -223,7 +223,13 @@ class TorchConstantCodeGen : public TorchOpCode { } } - void CodeGenForward() final { stack_.assign(IdxNode(), module_ref()); } + void CodeGenForward() final { + if (config()->use_tools) { + stack_.assign(IdxNode(), IdxWeight("const", true)); + } else { + stack_.assign(IdxNode(), module_ref()); + } + } }; class TorchConvCodeGen : public TorchOpCode { @@ -510,7 +516,7 @@ class TorchReshapeCodeGen : public TorchOpCode { const auto& out_layout = node()->OutputAt(0)->layout; if (out_layout.defined()) { int32_t batch_dim = out_layout.IndexOf(tvm::tir::LayoutAxis::Get("N")); - if (batch_dim > 0) { + if (batch_dim >= 0) { shape.Set(batch_dim, Integer(-1)); } } @@ -608,6 +614,21 @@ class TorchTupleCodeGen : public TorchOpCode { void CodeGenForward() final { stack_.op_call().op_inputs_arg(); } }; +class TorchPluginOpCodeGen : public TorchOpCode { + TORCH_OP_CODEGEN_METHODS(TorchPluginOpCodeGen) + + protected: + void CodeGenInit() final { + const auto& plugin = GetPlugin(node()->optype); + stack_.op_call("plugin." 
+ node()->optype); + for (const auto& a : plugin->attrs) { + stack_.call_arg(GetAttrDoc(a->name, a->type), a->name); + } + } + + void CodeGenForward() final { stack_.op_call().op_inputs_arg(false); } +}; + const std::shared_ptr>> GetTorchOpCodes() { static auto map = std::make_shared>>(); if (!map->empty()) return map; @@ -728,6 +749,7 @@ const std::shared_ptr>> map->emplace("get_item", std::make_shared("", "")); map->emplace("shape", std::make_shared("", "torch.Size")); map->emplace("tuple", std::make_shared("", "tuple")); + map->emplace("plugin", std::make_shared("Plugin", "")); // msc ops map->emplace("msc.attention", std::make_shared( @@ -743,7 +765,6 @@ const std::shared_ptr>> std::make_shared("nn.Linear", "functional.linear", false)); map->emplace("msc.linear_bias", std::make_shared("nn.Linear", "functional.linear", true)); - return map; } diff --git a/src/contrib/msc/framework/tvm/codegen.cc b/src/contrib/msc/framework/tvm/codegen.cc index 20c47d929125..783551eed35b 100644 --- a/src/contrib/msc/framework/tvm/codegen.cc +++ b/src/contrib/msc/framework/tvm/codegen.cc @@ -40,6 +40,9 @@ void RelaxCodeGen::CodeGenGraph() { stack_.func_arg(idx_input, "relax.Var"); idx_inputs.push_back(idx_input); } + if (config()->use_plugin) { + stack_.func_arg("plugin", "Any"); + } stack_.func_start().assign("inputs", DocUtils::ToList(idx_inputs, true)); // define weights stack_.comment("Define the weights"); @@ -123,6 +126,12 @@ void RelaxCodeGen::CodeGenGraph() { } void RelaxCodeGen::CodeGenInference() { + if (config()->use_plugin) { + stack_.comment("Import Plugin") + .line("from msc_plugin.tvm import PluginManager") + .line() + .func_call("PluginManager", "plugin"); + } for (const auto& i : graph()->GetInputs()) { const auto& producer = graph()->FindProducer(i); stack_.func_call("relax.Var", IdxNodeBase(producer)) @@ -133,6 +142,9 @@ void RelaxCodeGen::CodeGenInference() { .pop_nest(); } stack_.comment("Build Module").func_call(graph()->name, "mod"); + if (config()->use_plugin) { + stack_.call_arg("plugin"); + } for (const auto& i : graph()->GetInputs()) { const auto& producer = graph()->FindProducer(i); stack_.call_arg(IdxNodeBase(producer)); @@ -177,7 +189,7 @@ void RelaxCodeGen::CodeGenInference() { const Array RelaxCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetRelaxOpCodes(); - auto it = ops_map->find(node->optype); + auto it = ops_map->find(GetOpType(node)); ICHECK(it != ops_map->end()) << "Unsupported relax op(" << node->optype << "): " << node; it->second->Config(node, config()); try { diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc index 16b78193ae6a..0b7ef6aa825e 100644 --- a/src/contrib/msc/framework/tvm/relax_opcode.cc +++ b/src/contrib/msc/framework/tvm/relax_opcode.cc @@ -668,6 +668,19 @@ class RelaxTriCodeGen : public RelaxOpCode { } }; +class RelaxPluginOpCodeGen : public RelaxOpCode { + RELAX_OP_CODEGEN_METHODS(RelaxPluginOpCodeGen) + + protected: + void CodeGenBuild() final { + const auto& plugin = GetPlugin(node()->optype); + stack_.op_call("plugin." 
+ node()->optype).op_inputs_arg(false); + for (const auto& a : plugin->attrs) { + stack_.call_arg(GetAttrDoc(a->name, a->type), a->name); + } + } +}; + const std::shared_ptr>> GetRelaxOpCodes() { static auto map = std::make_shared>>(); if (!map->empty()) return map; @@ -798,6 +811,7 @@ const std::shared_ptr>> map->emplace("get_item", std::make_shared("relax.TupleGetItem")); map->emplace("shape", std::make_shared("relax.ShapeExpr")); map->emplace("tuple", std::make_shared("relax.Tuple")); + map->emplace("plugin", std::make_shared("Plugin")); // msc ops map->emplace("msc.attention", std::make_shared("relax.op.nn.attention"));